Spaces:
Sleeping
Sleeping
Staticaliza
commited on
Upload 4 files
Browse files- dac/nn/__init__.py +3 -0
- dac/nn/layers.py +33 -0
- dac/nn/loss.py +368 -0
- dac/nn/quantize.py +339 -0
dac/nn/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from . import layers
|
2 |
+
from . import loss
|
3 |
+
from . import quantize
|
dac/nn/layers.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange
|
6 |
+
from torch.nn.utils import weight_norm
|
7 |
+
|
8 |
+
|
9 |
+
def WNConv1d(*args, **kwargs):
|
10 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
11 |
+
|
12 |
+
|
13 |
+
def WNConvTranspose1d(*args, **kwargs):
|
14 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
15 |
+
|
16 |
+
|
17 |
+
# Scripting this brings model speed up 1.4x
|
18 |
+
@torch.jit.script
|
19 |
+
def snake(x, alpha):
|
20 |
+
shape = x.shape
|
21 |
+
x = x.reshape(shape[0], shape[1], -1)
|
22 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
23 |
+
x = x.reshape(shape)
|
24 |
+
return x
|
25 |
+
|
26 |
+
|
27 |
+
class Snake1d(nn.Module):
|
28 |
+
def __init__(self, channels):
|
29 |
+
super().__init__()
|
30 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
return snake(x, self.alpha)
|
dac/nn/loss.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from audiotools import AudioSignal
|
7 |
+
from audiotools import STFTParams
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
|
11 |
+
class L1Loss(nn.L1Loss):
|
12 |
+
"""L1 Loss between AudioSignals. Defaults
|
13 |
+
to comparing ``audio_data``, but any
|
14 |
+
attribute of an AudioSignal can be used.
|
15 |
+
|
16 |
+
Parameters
|
17 |
+
----------
|
18 |
+
attribute : str, optional
|
19 |
+
Attribute of signal to compare, defaults to ``audio_data``.
|
20 |
+
weight : float, optional
|
21 |
+
Weight of this loss, defaults to 1.0.
|
22 |
+
|
23 |
+
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
|
27 |
+
self.attribute = attribute
|
28 |
+
self.weight = weight
|
29 |
+
super().__init__(**kwargs)
|
30 |
+
|
31 |
+
def forward(self, x: AudioSignal, y: AudioSignal):
|
32 |
+
"""
|
33 |
+
Parameters
|
34 |
+
----------
|
35 |
+
x : AudioSignal
|
36 |
+
Estimate AudioSignal
|
37 |
+
y : AudioSignal
|
38 |
+
Reference AudioSignal
|
39 |
+
|
40 |
+
Returns
|
41 |
+
-------
|
42 |
+
torch.Tensor
|
43 |
+
L1 loss between AudioSignal attributes.
|
44 |
+
"""
|
45 |
+
if isinstance(x, AudioSignal):
|
46 |
+
x = getattr(x, self.attribute)
|
47 |
+
y = getattr(y, self.attribute)
|
48 |
+
return super().forward(x, y)
|
49 |
+
|
50 |
+
|
51 |
+
class SISDRLoss(nn.Module):
|
52 |
+
"""
|
53 |
+
Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
|
54 |
+
of estimated and reference audio signals or aligned features.
|
55 |
+
|
56 |
+
Parameters
|
57 |
+
----------
|
58 |
+
scaling : int, optional
|
59 |
+
Whether to use scale-invariant (True) or
|
60 |
+
signal-to-noise ratio (False), by default True
|
61 |
+
reduction : str, optional
|
62 |
+
How to reduce across the batch (either 'mean',
|
63 |
+
'sum', or none).], by default ' mean'
|
64 |
+
zero_mean : int, optional
|
65 |
+
Zero mean the references and estimates before
|
66 |
+
computing the loss, by default True
|
67 |
+
clip_min : int, optional
|
68 |
+
The minimum possible loss value. Helps network
|
69 |
+
to not focus on making already good examples better, by default None
|
70 |
+
weight : float, optional
|
71 |
+
Weight of this loss, defaults to 1.0.
|
72 |
+
|
73 |
+
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
scaling: int = True,
|
79 |
+
reduction: str = "mean",
|
80 |
+
zero_mean: int = True,
|
81 |
+
clip_min: int = None,
|
82 |
+
weight: float = 1.0,
|
83 |
+
):
|
84 |
+
self.scaling = scaling
|
85 |
+
self.reduction = reduction
|
86 |
+
self.zero_mean = zero_mean
|
87 |
+
self.clip_min = clip_min
|
88 |
+
self.weight = weight
|
89 |
+
super().__init__()
|
90 |
+
|
91 |
+
def forward(self, x: AudioSignal, y: AudioSignal):
|
92 |
+
eps = 1e-8
|
93 |
+
# nb, nc, nt
|
94 |
+
if isinstance(x, AudioSignal):
|
95 |
+
references = x.audio_data
|
96 |
+
estimates = y.audio_data
|
97 |
+
else:
|
98 |
+
references = x
|
99 |
+
estimates = y
|
100 |
+
|
101 |
+
nb = references.shape[0]
|
102 |
+
references = references.reshape(nb, 1, -1).permute(0, 2, 1)
|
103 |
+
estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
|
104 |
+
|
105 |
+
# samples now on axis 1
|
106 |
+
if self.zero_mean:
|
107 |
+
mean_reference = references.mean(dim=1, keepdim=True)
|
108 |
+
mean_estimate = estimates.mean(dim=1, keepdim=True)
|
109 |
+
else:
|
110 |
+
mean_reference = 0
|
111 |
+
mean_estimate = 0
|
112 |
+
|
113 |
+
_references = references - mean_reference
|
114 |
+
_estimates = estimates - mean_estimate
|
115 |
+
|
116 |
+
references_projection = (_references**2).sum(dim=-2) + eps
|
117 |
+
references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
|
118 |
+
|
119 |
+
scale = (
|
120 |
+
(references_on_estimates / references_projection).unsqueeze(1)
|
121 |
+
if self.scaling
|
122 |
+
else 1
|
123 |
+
)
|
124 |
+
|
125 |
+
e_true = scale * _references
|
126 |
+
e_res = _estimates - e_true
|
127 |
+
|
128 |
+
signal = (e_true**2).sum(dim=1)
|
129 |
+
noise = (e_res**2).sum(dim=1)
|
130 |
+
sdr = -10 * torch.log10(signal / noise + eps)
|
131 |
+
|
132 |
+
if self.clip_min is not None:
|
133 |
+
sdr = torch.clamp(sdr, min=self.clip_min)
|
134 |
+
|
135 |
+
if self.reduction == "mean":
|
136 |
+
sdr = sdr.mean()
|
137 |
+
elif self.reduction == "sum":
|
138 |
+
sdr = sdr.sum()
|
139 |
+
return sdr
|
140 |
+
|
141 |
+
|
142 |
+
class MultiScaleSTFTLoss(nn.Module):
|
143 |
+
"""Computes the multi-scale STFT loss from [1].
|
144 |
+
|
145 |
+
Parameters
|
146 |
+
----------
|
147 |
+
window_lengths : List[int], optional
|
148 |
+
Length of each window of each STFT, by default [2048, 512]
|
149 |
+
loss_fn : typing.Callable, optional
|
150 |
+
How to compare each loss, by default nn.L1Loss()
|
151 |
+
clamp_eps : float, optional
|
152 |
+
Clamp on the log magnitude, below, by default 1e-5
|
153 |
+
mag_weight : float, optional
|
154 |
+
Weight of raw magnitude portion of loss, by default 1.0
|
155 |
+
log_weight : float, optional
|
156 |
+
Weight of log magnitude portion of loss, by default 1.0
|
157 |
+
pow : float, optional
|
158 |
+
Power to raise magnitude to before taking log, by default 2.0
|
159 |
+
weight : float, optional
|
160 |
+
Weight of this loss, by default 1.0
|
161 |
+
match_stride : bool, optional
|
162 |
+
Whether to match the stride of convolutional layers, by default False
|
163 |
+
|
164 |
+
References
|
165 |
+
----------
|
166 |
+
|
167 |
+
1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
|
168 |
+
"DDSP: Differentiable Digital Signal Processing."
|
169 |
+
International Conference on Learning Representations. 2019.
|
170 |
+
|
171 |
+
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
|
172 |
+
"""
|
173 |
+
|
174 |
+
def __init__(
|
175 |
+
self,
|
176 |
+
window_lengths: List[int] = [2048, 512],
|
177 |
+
loss_fn: typing.Callable = nn.L1Loss(),
|
178 |
+
clamp_eps: float = 1e-5,
|
179 |
+
mag_weight: float = 1.0,
|
180 |
+
log_weight: float = 1.0,
|
181 |
+
pow: float = 2.0,
|
182 |
+
weight: float = 1.0,
|
183 |
+
match_stride: bool = False,
|
184 |
+
window_type: str = None,
|
185 |
+
):
|
186 |
+
super().__init__()
|
187 |
+
self.stft_params = [
|
188 |
+
STFTParams(
|
189 |
+
window_length=w,
|
190 |
+
hop_length=w // 4,
|
191 |
+
match_stride=match_stride,
|
192 |
+
window_type=window_type,
|
193 |
+
)
|
194 |
+
for w in window_lengths
|
195 |
+
]
|
196 |
+
self.loss_fn = loss_fn
|
197 |
+
self.log_weight = log_weight
|
198 |
+
self.mag_weight = mag_weight
|
199 |
+
self.clamp_eps = clamp_eps
|
200 |
+
self.weight = weight
|
201 |
+
self.pow = pow
|
202 |
+
|
203 |
+
def forward(self, x: AudioSignal, y: AudioSignal):
|
204 |
+
"""Computes multi-scale STFT between an estimate and a reference
|
205 |
+
signal.
|
206 |
+
|
207 |
+
Parameters
|
208 |
+
----------
|
209 |
+
x : AudioSignal
|
210 |
+
Estimate signal
|
211 |
+
y : AudioSignal
|
212 |
+
Reference signal
|
213 |
+
|
214 |
+
Returns
|
215 |
+
-------
|
216 |
+
torch.Tensor
|
217 |
+
Multi-scale STFT loss.
|
218 |
+
"""
|
219 |
+
loss = 0.0
|
220 |
+
for s in self.stft_params:
|
221 |
+
x.stft(s.window_length, s.hop_length, s.window_type)
|
222 |
+
y.stft(s.window_length, s.hop_length, s.window_type)
|
223 |
+
loss += self.log_weight * self.loss_fn(
|
224 |
+
x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
|
225 |
+
y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
|
226 |
+
)
|
227 |
+
loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
|
228 |
+
return loss
|
229 |
+
|
230 |
+
|
231 |
+
class MelSpectrogramLoss(nn.Module):
|
232 |
+
"""Compute distance between mel spectrograms. Can be used
|
233 |
+
in a multi-scale way.
|
234 |
+
|
235 |
+
Parameters
|
236 |
+
----------
|
237 |
+
n_mels : List[int]
|
238 |
+
Number of mels per STFT, by default [150, 80],
|
239 |
+
window_lengths : List[int], optional
|
240 |
+
Length of each window of each STFT, by default [2048, 512]
|
241 |
+
loss_fn : typing.Callable, optional
|
242 |
+
How to compare each loss, by default nn.L1Loss()
|
243 |
+
clamp_eps : float, optional
|
244 |
+
Clamp on the log magnitude, below, by default 1e-5
|
245 |
+
mag_weight : float, optional
|
246 |
+
Weight of raw magnitude portion of loss, by default 1.0
|
247 |
+
log_weight : float, optional
|
248 |
+
Weight of log magnitude portion of loss, by default 1.0
|
249 |
+
pow : float, optional
|
250 |
+
Power to raise magnitude to before taking log, by default 2.0
|
251 |
+
weight : float, optional
|
252 |
+
Weight of this loss, by default 1.0
|
253 |
+
match_stride : bool, optional
|
254 |
+
Whether to match the stride of convolutional layers, by default False
|
255 |
+
|
256 |
+
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
|
257 |
+
"""
|
258 |
+
|
259 |
+
def __init__(
|
260 |
+
self,
|
261 |
+
n_mels: List[int] = [150, 80],
|
262 |
+
window_lengths: List[int] = [2048, 512],
|
263 |
+
loss_fn: typing.Callable = nn.L1Loss(),
|
264 |
+
clamp_eps: float = 1e-5,
|
265 |
+
mag_weight: float = 1.0,
|
266 |
+
log_weight: float = 1.0,
|
267 |
+
pow: float = 2.0,
|
268 |
+
weight: float = 1.0,
|
269 |
+
match_stride: bool = False,
|
270 |
+
mel_fmin: List[float] = [0.0, 0.0],
|
271 |
+
mel_fmax: List[float] = [None, None],
|
272 |
+
window_type: str = None,
|
273 |
+
):
|
274 |
+
super().__init__()
|
275 |
+
self.stft_params = [
|
276 |
+
STFTParams(
|
277 |
+
window_length=w,
|
278 |
+
hop_length=w // 4,
|
279 |
+
match_stride=match_stride,
|
280 |
+
window_type=window_type,
|
281 |
+
)
|
282 |
+
for w in window_lengths
|
283 |
+
]
|
284 |
+
self.n_mels = n_mels
|
285 |
+
self.loss_fn = loss_fn
|
286 |
+
self.clamp_eps = clamp_eps
|
287 |
+
self.log_weight = log_weight
|
288 |
+
self.mag_weight = mag_weight
|
289 |
+
self.weight = weight
|
290 |
+
self.mel_fmin = mel_fmin
|
291 |
+
self.mel_fmax = mel_fmax
|
292 |
+
self.pow = pow
|
293 |
+
|
294 |
+
def forward(self, x: AudioSignal, y: AudioSignal):
|
295 |
+
"""Computes mel loss between an estimate and a reference
|
296 |
+
signal.
|
297 |
+
|
298 |
+
Parameters
|
299 |
+
----------
|
300 |
+
x : AudioSignal
|
301 |
+
Estimate signal
|
302 |
+
y : AudioSignal
|
303 |
+
Reference signal
|
304 |
+
|
305 |
+
Returns
|
306 |
+
-------
|
307 |
+
torch.Tensor
|
308 |
+
Mel loss.
|
309 |
+
"""
|
310 |
+
loss = 0.0
|
311 |
+
for n_mels, fmin, fmax, s in zip(
|
312 |
+
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
|
313 |
+
):
|
314 |
+
kwargs = {
|
315 |
+
"window_length": s.window_length,
|
316 |
+
"hop_length": s.hop_length,
|
317 |
+
"window_type": s.window_type,
|
318 |
+
}
|
319 |
+
x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
|
320 |
+
y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
|
321 |
+
|
322 |
+
loss += self.log_weight * self.loss_fn(
|
323 |
+
x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
|
324 |
+
y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
|
325 |
+
)
|
326 |
+
loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
|
327 |
+
return loss
|
328 |
+
|
329 |
+
|
330 |
+
class GANLoss(nn.Module):
|
331 |
+
"""
|
332 |
+
Computes a discriminator loss, given a discriminator on
|
333 |
+
generated waveforms/spectrograms compared to ground truth
|
334 |
+
waveforms/spectrograms. Computes the loss for both the
|
335 |
+
discriminator and the generator in separate functions.
|
336 |
+
"""
|
337 |
+
|
338 |
+
def __init__(self, discriminator):
|
339 |
+
super().__init__()
|
340 |
+
self.discriminator = discriminator
|
341 |
+
|
342 |
+
def forward(self, fake, real):
|
343 |
+
d_fake = self.discriminator(fake.audio_data)
|
344 |
+
d_real = self.discriminator(real.audio_data)
|
345 |
+
return d_fake, d_real
|
346 |
+
|
347 |
+
def discriminator_loss(self, fake, real):
|
348 |
+
d_fake, d_real = self.forward(fake.clone().detach(), real)
|
349 |
+
|
350 |
+
loss_d = 0
|
351 |
+
for x_fake, x_real in zip(d_fake, d_real):
|
352 |
+
loss_d += torch.mean(x_fake[-1] ** 2)
|
353 |
+
loss_d += torch.mean((1 - x_real[-1]) ** 2)
|
354 |
+
return loss_d
|
355 |
+
|
356 |
+
def generator_loss(self, fake, real):
|
357 |
+
d_fake, d_real = self.forward(fake, real)
|
358 |
+
|
359 |
+
loss_g = 0
|
360 |
+
for x_fake in d_fake:
|
361 |
+
loss_g += torch.mean((1 - x_fake[-1]) ** 2)
|
362 |
+
|
363 |
+
loss_feature = 0
|
364 |
+
|
365 |
+
for i in range(len(d_fake)):
|
366 |
+
for j in range(len(d_fake[i]) - 1):
|
367 |
+
loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
|
368 |
+
return loss_g, loss_feature
|
dac/nn/quantize.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from torch.nn.utils import weight_norm
|
9 |
+
|
10 |
+
from dac.nn.layers import WNConv1d
|
11 |
+
|
12 |
+
class VectorQuantizeLegacy(nn.Module):
|
13 |
+
"""
|
14 |
+
Implementation of VQ similar to Karpathy's repo:
|
15 |
+
https://github.com/karpathy/deep-vector-quantization
|
16 |
+
removed in-out projection
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, input_dim: int, codebook_size: int):
|
20 |
+
super().__init__()
|
21 |
+
self.codebook_size = codebook_size
|
22 |
+
self.codebook = nn.Embedding(codebook_size, input_dim)
|
23 |
+
|
24 |
+
def forward(self, z, z_mask=None):
|
25 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
26 |
+
the corresponding codebook vectors
|
27 |
+
|
28 |
+
Parameters
|
29 |
+
----------
|
30 |
+
z : Tensor[B x D x T]
|
31 |
+
|
32 |
+
Returns
|
33 |
+
-------
|
34 |
+
Tensor[B x D x T]
|
35 |
+
Quantized continuous representation of input
|
36 |
+
Tensor[1]
|
37 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
38 |
+
entries
|
39 |
+
Tensor[1]
|
40 |
+
Codebook loss to update the codebook
|
41 |
+
Tensor[B x T]
|
42 |
+
Codebook indices (quantized discrete representation of input)
|
43 |
+
Tensor[B x D x T]
|
44 |
+
Projected latents (continuous representation of input before quantization)
|
45 |
+
"""
|
46 |
+
|
47 |
+
z_e = z
|
48 |
+
z_q, indices = self.decode_latents(z)
|
49 |
+
|
50 |
+
if z_mask is not None:
|
51 |
+
commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
|
52 |
+
codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
|
53 |
+
else:
|
54 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach())
|
55 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach())
|
56 |
+
z_q = (
|
57 |
+
z_e + (z_q - z_e).detach()
|
58 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
59 |
+
|
60 |
+
return z_q, indices, z_e, commitment_loss, codebook_loss
|
61 |
+
|
62 |
+
def embed_code(self, embed_id):
|
63 |
+
return F.embedding(embed_id, self.codebook.weight)
|
64 |
+
|
65 |
+
def decode_code(self, embed_id):
|
66 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
67 |
+
|
68 |
+
def decode_latents(self, latents):
|
69 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
70 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
71 |
+
|
72 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
73 |
+
encodings = F.normalize(encodings)
|
74 |
+
codebook = F.normalize(codebook)
|
75 |
+
|
76 |
+
# Compute euclidean distance with codebook
|
77 |
+
dist = (
|
78 |
+
encodings.pow(2).sum(1, keepdim=True)
|
79 |
+
- 2 * encodings @ codebook.t()
|
80 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
81 |
+
)
|
82 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
83 |
+
z_q = self.decode_code(indices)
|
84 |
+
return z_q, indices
|
85 |
+
|
86 |
+
class VectorQuantize(nn.Module):
|
87 |
+
"""
|
88 |
+
Implementation of VQ similar to Karpathy's repo:
|
89 |
+
https://github.com/karpathy/deep-vector-quantization
|
90 |
+
Additionally uses following tricks from Improved VQGAN
|
91 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
92 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
93 |
+
for improved codebook usage
|
94 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
95 |
+
improves training stability
|
96 |
+
"""
|
97 |
+
|
98 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
|
99 |
+
super().__init__()
|
100 |
+
self.codebook_size = codebook_size
|
101 |
+
self.codebook_dim = codebook_dim
|
102 |
+
|
103 |
+
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
104 |
+
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
105 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
106 |
+
|
107 |
+
def forward(self, z, z_mask=None):
|
108 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
109 |
+
the corresponding codebook vectors
|
110 |
+
|
111 |
+
Parameters
|
112 |
+
----------
|
113 |
+
z : Tensor[B x D x T]
|
114 |
+
|
115 |
+
Returns
|
116 |
+
-------
|
117 |
+
Tensor[B x D x T]
|
118 |
+
Quantized continuous representation of input
|
119 |
+
Tensor[1]
|
120 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
121 |
+
entries
|
122 |
+
Tensor[1]
|
123 |
+
Codebook loss to update the codebook
|
124 |
+
Tensor[B x T]
|
125 |
+
Codebook indices (quantized discrete representation of input)
|
126 |
+
Tensor[B x D x T]
|
127 |
+
Projected latents (continuous representation of input before quantization)
|
128 |
+
"""
|
129 |
+
|
130 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
131 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
132 |
+
z_q, indices = self.decode_latents(z_e)
|
133 |
+
|
134 |
+
if z_mask is not None:
|
135 |
+
commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
|
136 |
+
codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
|
137 |
+
else:
|
138 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach())
|
139 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach())
|
140 |
+
|
141 |
+
z_q = (
|
142 |
+
z_e + (z_q - z_e).detach()
|
143 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
144 |
+
|
145 |
+
z_q = self.out_proj(z_q)
|
146 |
+
|
147 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
148 |
+
|
149 |
+
def embed_code(self, embed_id):
|
150 |
+
return F.embedding(embed_id, self.codebook.weight)
|
151 |
+
|
152 |
+
def decode_code(self, embed_id):
|
153 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
154 |
+
|
155 |
+
def decode_latents(self, latents):
|
156 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
157 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
158 |
+
|
159 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
160 |
+
encodings = F.normalize(encodings)
|
161 |
+
codebook = F.normalize(codebook)
|
162 |
+
|
163 |
+
# Compute euclidean distance with codebook
|
164 |
+
dist = (
|
165 |
+
encodings.pow(2).sum(1, keepdim=True)
|
166 |
+
- 2 * encodings @ codebook.t()
|
167 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
168 |
+
)
|
169 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
170 |
+
z_q = self.decode_code(indices)
|
171 |
+
return z_q, indices
|
172 |
+
|
173 |
+
|
174 |
+
class ResidualVectorQuantize(nn.Module):
|
175 |
+
"""
|
176 |
+
Introduced in SoundStream: An end2end neural audio codec
|
177 |
+
https://arxiv.org/abs/2107.03312
|
178 |
+
"""
|
179 |
+
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
input_dim: int = 512,
|
183 |
+
n_codebooks: int = 9,
|
184 |
+
codebook_size: int = 1024,
|
185 |
+
codebook_dim: Union[int, list] = 8,
|
186 |
+
quantizer_dropout: float = 0.0,
|
187 |
+
):
|
188 |
+
super().__init__()
|
189 |
+
if isinstance(codebook_dim, int):
|
190 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
191 |
+
|
192 |
+
self.n_codebooks = n_codebooks
|
193 |
+
self.codebook_dim = codebook_dim
|
194 |
+
self.codebook_size = codebook_size
|
195 |
+
|
196 |
+
self.quantizers = nn.ModuleList(
|
197 |
+
[
|
198 |
+
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
|
199 |
+
for i in range(n_codebooks)
|
200 |
+
]
|
201 |
+
)
|
202 |
+
self.quantizer_dropout = quantizer_dropout
|
203 |
+
|
204 |
+
def forward(self, z, n_quantizers: int = None):
|
205 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
206 |
+
the corresponding codebook vectors
|
207 |
+
Parameters
|
208 |
+
----------
|
209 |
+
z : Tensor[B x D x T]
|
210 |
+
n_quantizers : int, optional
|
211 |
+
No. of quantizers to use
|
212 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
213 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
214 |
+
when in training mode, and a random number of quantizers is used.
|
215 |
+
Returns
|
216 |
+
-------
|
217 |
+
dict
|
218 |
+
A dictionary with the following keys:
|
219 |
+
|
220 |
+
"z" : Tensor[B x D x T]
|
221 |
+
Quantized continuous representation of input
|
222 |
+
"codes" : Tensor[B x N x T]
|
223 |
+
Codebook indices for each codebook
|
224 |
+
(quantized discrete representation of input)
|
225 |
+
"latents" : Tensor[B x N*D x T]
|
226 |
+
Projected latents (continuous representation of input before quantization)
|
227 |
+
"vq/commitment_loss" : Tensor[1]
|
228 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
229 |
+
entries
|
230 |
+
"vq/codebook_loss" : Tensor[1]
|
231 |
+
Codebook loss to update the codebook
|
232 |
+
"""
|
233 |
+
z_q = 0
|
234 |
+
residual = z
|
235 |
+
commitment_loss = 0
|
236 |
+
codebook_loss = 0
|
237 |
+
|
238 |
+
codebook_indices = []
|
239 |
+
latents = []
|
240 |
+
|
241 |
+
if n_quantizers is None:
|
242 |
+
n_quantizers = self.n_codebooks
|
243 |
+
if self.training:
|
244 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
245 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
246 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
247 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
248 |
+
n_quantizers = n_quantizers.to(z.device)
|
249 |
+
|
250 |
+
for i, quantizer in enumerate(self.quantizers):
|
251 |
+
if self.training is False and i >= n_quantizers:
|
252 |
+
break
|
253 |
+
|
254 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
255 |
+
residual
|
256 |
+
)
|
257 |
+
|
258 |
+
# Create mask to apply quantizer dropout
|
259 |
+
mask = (
|
260 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
261 |
+
)
|
262 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
263 |
+
residual = residual - z_q_i
|
264 |
+
|
265 |
+
# Sum losses
|
266 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
267 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
268 |
+
|
269 |
+
codebook_indices.append(indices_i)
|
270 |
+
latents.append(z_e_i)
|
271 |
+
|
272 |
+
codes = torch.stack(codebook_indices, dim=1)
|
273 |
+
latents = torch.cat(latents, dim=1)
|
274 |
+
|
275 |
+
return z_q, codes, latents, commitment_loss, codebook_loss
|
276 |
+
|
277 |
+
def from_codes(self, codes: torch.Tensor):
|
278 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
279 |
+
Parameters
|
280 |
+
----------
|
281 |
+
codes : Tensor[B x N x T]
|
282 |
+
Quantized discrete representation of input
|
283 |
+
Returns
|
284 |
+
-------
|
285 |
+
Tensor[B x D x T]
|
286 |
+
Quantized continuous representation of input
|
287 |
+
"""
|
288 |
+
z_q = 0.0
|
289 |
+
z_p = []
|
290 |
+
n_codebooks = codes.shape[1]
|
291 |
+
for i in range(n_codebooks):
|
292 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
293 |
+
z_p.append(z_p_i)
|
294 |
+
|
295 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
296 |
+
z_q = z_q + z_q_i
|
297 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
298 |
+
|
299 |
+
def from_latents(self, latents: torch.Tensor):
|
300 |
+
"""Given the unquantized latents, reconstruct the
|
301 |
+
continuous representation after quantization.
|
302 |
+
|
303 |
+
Parameters
|
304 |
+
----------
|
305 |
+
latents : Tensor[B x N x T]
|
306 |
+
Continuous representation of input after projection
|
307 |
+
|
308 |
+
Returns
|
309 |
+
-------
|
310 |
+
Tensor[B x D x T]
|
311 |
+
Quantized representation of full-projected space
|
312 |
+
Tensor[B x D x T]
|
313 |
+
Quantized representation of latent space
|
314 |
+
"""
|
315 |
+
z_q = 0
|
316 |
+
z_p = []
|
317 |
+
codes = []
|
318 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
319 |
+
|
320 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
321 |
+
0
|
322 |
+
]
|
323 |
+
for i in range(n_codebooks):
|
324 |
+
j, k = dims[i], dims[i + 1]
|
325 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
326 |
+
z_p.append(z_p_i)
|
327 |
+
codes.append(codes_i)
|
328 |
+
|
329 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
330 |
+
z_q = z_q + z_q_i
|
331 |
+
|
332 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
333 |
+
|
334 |
+
|
335 |
+
if __name__ == "__main__":
|
336 |
+
rvq = ResidualVectorQuantize(quantizer_dropout=True)
|
337 |
+
x = torch.randn(16, 512, 80)
|
338 |
+
y = rvq(x)
|
339 |
+
print(y["latents"].shape)
|