Staticaliza committed on
Commit 9b5f7e0 · verified · 1 Parent(s): 089208d

Delete dac

dac/__init__.py DELETED
@@ -1,16 +0,0 @@
- __version__ = "1.0.0"
-
- # preserved here for legacy reasons
- __model_version__ = "latest"
-
- import audiotools
-
- audiotools.ml.BaseModel.INTERN += ["dac.**"]
- audiotools.ml.BaseModel.EXTERN += ["einops"]
-
-
- from . import nn
- from . import model
- from . import utils
- from .model import DAC
- from .model import DACFile
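Editor's note: the INTERN/EXTERN lines above register the package with audiotools' torch.package-based model serialization, and the rest of the file defines the public surface that the deleted modules below fill in. A minimal usage sketch, assuming the package was importable as `dac` (a fresh `DAC()` here is randomly initialized, not pretrained):

# Usage sketch only; names come from the __init__ above.
import dac

print(dac.__version__)   # "1.0.0"
model = dac.DAC()        # re-exported from dac.model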
 
dac/__main__.py DELETED
@@ -1,36 +0,0 @@
- import sys
-
- import argbind
-
- from dac.utils import download
- from dac.utils.decode import decode
- from dac.utils.encode import encode
-
- STAGES = ["encode", "decode", "download"]
-
-
- def run(stage: str):
-     """Run stages.
-
-     Parameters
-     ----------
-     stage : str
-         Stage to run
-     """
-     if stage not in STAGES:
-         raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
-     stage_fn = globals()[stage]
-
-     if stage == "download":
-         stage_fn()
-         return
-
-     stage_fn()
-
-
- if __name__ == "__main__":
-     group = sys.argv.pop(1)
-     args = argbind.parse_args(group=group)
-
-     with argbind.scope(args):
-         run(group)
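Editor's note: the entry point dispatched on the first positional argument, so the deleted CLI was driven as `python -m dac <stage>` with `<stage>` one of the STAGES above. A hedged sketch of the programmatic equivalent (the encode/decode argbind flags lived in the deleted dac/utils modules and are not shown in this diff):

# Programmatic equivalent of `python -m dac download`, using run() above.
import argbind
from dac.__main__ import run

args = argbind.parse_args(group="download")
with argbind.scope(args):
    run("download")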
 
dac/model/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .base import CodecMixin
- from .base import DACFile
- from .dac import DAC
- from .discriminator import Discriminator
 
dac/model/base.py DELETED
@@ -1,294 +0,0 @@
- import math
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Union
-
- import numpy as np
- import torch
- import tqdm
- from audiotools import AudioSignal
- from torch import nn
-
- SUPPORTED_VERSIONS = ["1.0.0"]
-
-
- @dataclass
- class DACFile:
-     codes: torch.Tensor
-
-     # Metadata
-     chunk_length: int
-     original_length: int
-     input_db: float
-     channels: int
-     sample_rate: int
-     padding: bool
-     dac_version: str
-
-     def save(self, path):
-         artifacts = {
-             "codes": self.codes.numpy().astype(np.uint16),
-             "metadata": {
-                 "input_db": self.input_db.numpy().astype(np.float32),
-                 "original_length": self.original_length,
-                 "sample_rate": self.sample_rate,
-                 "chunk_length": self.chunk_length,
-                 "channels": self.channels,
-                 "padding": self.padding,
-                 "dac_version": SUPPORTED_VERSIONS[-1],
-             },
-         }
-         path = Path(path).with_suffix(".dac")
-         with open(path, "wb") as f:
-             np.save(f, artifacts)
-         return path
-
-     @classmethod
-     def load(cls, path):
-         artifacts = np.load(path, allow_pickle=True)[()]
-         codes = torch.from_numpy(artifacts["codes"].astype(int))
-         if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
-             raise RuntimeError(
-                 f"Given file {path} can't be loaded with this version of descript-audio-codec."
-             )
-         return cls(codes=codes, **artifacts["metadata"])
-
-
- class CodecMixin:
-     @property
-     def padding(self):
-         if not hasattr(self, "_padding"):
-             self._padding = True
-         return self._padding
-
-     @padding.setter
-     def padding(self, value):
-         assert isinstance(value, bool)
-
-         layers = [
-             l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
-         ]
-
-         for layer in layers:
-             if value:
-                 if hasattr(layer, "original_padding"):
-                     layer.padding = layer.original_padding
-             else:
-                 layer.original_padding = layer.padding
-                 layer.padding = tuple(0 for _ in range(len(layer.padding)))
-
-         self._padding = value
-
-     def get_delay(self):
-         # Any number works here, delay is invariant to input length
-         l_out = self.get_output_length(0)
-         L = l_out
-
-         layers = []
-         for layer in self.modules():
-             if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
-                 layers.append(layer)
-
-         for layer in reversed(layers):
-             d = layer.dilation[0]
-             k = layer.kernel_size[0]
-             s = layer.stride[0]
-
-             if isinstance(layer, nn.ConvTranspose1d):
-                 L = ((L - d * (k - 1) - 1) / s) + 1
-             elif isinstance(layer, nn.Conv1d):
-                 L = (L - 1) * s + d * (k - 1) + 1
-
-             L = math.ceil(L)
-
-         l_in = L
-
-         return (l_in - l_out) // 2
-
-     def get_output_length(self, input_length):
-         L = input_length
-         # Calculate output length
-         for layer in self.modules():
-             if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
-                 d = layer.dilation[0]
-                 k = layer.kernel_size[0]
-                 s = layer.stride[0]
-
-                 if isinstance(layer, nn.Conv1d):
-                     L = ((L - d * (k - 1) - 1) / s) + 1
-                 elif isinstance(layer, nn.ConvTranspose1d):
-                     L = (L - 1) * s + d * (k - 1) + 1
-
-                 L = math.floor(L)
-         return L
-
-     @torch.no_grad()
-     def compress(
-         self,
-         audio_path_or_signal: Union[str, Path, AudioSignal],
-         win_duration: float = 1.0,
-         verbose: bool = False,
-         normalize_db: float = -16,
-         n_quantizers: int = None,
-     ) -> DACFile:
-         """Processes an audio signal from a file or AudioSignal object into
-         discrete codes. This function processes the signal in short windows,
-         using constant GPU memory.
-
-         Parameters
-         ----------
-         audio_path_or_signal : Union[str, Path, AudioSignal]
-             audio signal to reconstruct
-         win_duration : float, optional
-             window duration in seconds, by default 1.0
-         verbose : bool, optional
-             by default False
-         normalize_db : float, optional
-             normalize db, by default -16
-         n_quantizers : int, optional
-             number of quantizers to use, by default None (use all)
-
-         Returns
-         -------
-         DACFile
-             Object containing compressed codes and metadata
-             required for decompression
-         """
-         audio_signal = audio_path_or_signal
-         if isinstance(audio_signal, (str, Path)):
-             audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
-
-         self.eval()
-         original_padding = self.padding
-         original_device = audio_signal.device
-
-         audio_signal = audio_signal.clone()
-         original_sr = audio_signal.sample_rate
-
-         resample_fn = audio_signal.resample
-         loudness_fn = audio_signal.loudness
-
-         # If audio is >= 10 hours long, use the ffmpeg versions
-         if audio_signal.signal_duration >= 10 * 60 * 60:
-             resample_fn = audio_signal.ffmpeg_resample
-             loudness_fn = audio_signal.ffmpeg_loudness
-
-         original_length = audio_signal.signal_length
-         resample_fn(self.sample_rate)
-         input_db = loudness_fn()
-
-         if normalize_db is not None:
-             audio_signal.normalize(normalize_db)
-         audio_signal.ensure_max_of_audio()
-
-         nb, nac, nt = audio_signal.audio_data.shape
-         audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
-         win_duration = (
-             audio_signal.signal_duration if win_duration is None else win_duration
-         )
-
-         if audio_signal.signal_duration <= win_duration:
-             # Unchunked compression (used if signal length < win duration)
-             self.padding = True
-             n_samples = nt
-             hop = nt
-         else:
-             # Chunked inference
-             self.padding = False
-             # Zero-pad signal on either side by the delay
-             audio_signal.zero_pad(self.delay, self.delay)
-             n_samples = int(win_duration * self.sample_rate)
-             # Round n_samples to nearest hop length multiple
-             n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
-             hop = self.get_output_length(n_samples)
-
-         codes = []
-         range_fn = range if not verbose else tqdm.trange
-
-         for i in range_fn(0, nt, hop):
-             x = audio_signal[..., i : i + n_samples]
-             x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
-
-             audio_data = x.audio_data.to(self.device)
-             audio_data = self.preprocess(audio_data, self.sample_rate)
-             _, c, _, _, _ = self.encode(audio_data, n_quantizers)
-             codes.append(c.to(original_device))
-             chunk_length = c.shape[-1]
-
-         codes = torch.cat(codes, dim=-1)
-
-         dac_file = DACFile(
-             codes=codes,
-             chunk_length=chunk_length,
-             original_length=original_length,
-             input_db=input_db,
-             channels=nac,
-             sample_rate=original_sr,
-             padding=self.padding,
-             dac_version=SUPPORTED_VERSIONS[-1],
-         )
-
-         if n_quantizers is not None:
-             codes = codes[:, :n_quantizers, :]
-
-         self.padding = original_padding
-         return dac_file
-
-     @torch.no_grad()
-     def decompress(
-         self,
-         obj: Union[str, Path, DACFile],
-         verbose: bool = False,
-     ) -> AudioSignal:
-         """Reconstruct audio from a given .dac file
-
-         Parameters
-         ----------
-         obj : Union[str, Path, DACFile]
-             .dac file location or corresponding DACFile object.
-         verbose : bool, optional
-             Prints progress if True, by default False
-
-         Returns
-         -------
-         AudioSignal
-             Object with the reconstructed audio
-         """
-         self.eval()
-         if isinstance(obj, (str, Path)):
-             obj = DACFile.load(obj)
-
-         original_padding = self.padding
-         self.padding = obj.padding
-
-         range_fn = range if not verbose else tqdm.trange
-         codes = obj.codes
-         original_device = codes.device
-         chunk_length = obj.chunk_length
-         recons = []
-
-         for i in range_fn(0, codes.shape[-1], chunk_length):
-             c = codes[..., i : i + chunk_length].to(self.device)
-             z = self.quantizer.from_codes(c)[0]
-             r = self.decode(z)
-             recons.append(r.to(original_device))
-
-         recons = torch.cat(recons, dim=-1)
-         recons = AudioSignal(recons, self.sample_rate)
-
-         resample_fn = recons.resample
-         loudness_fn = recons.loudness
-
-         # If audio is >= 10 hours long, use the ffmpeg versions
-         if recons.signal_duration >= 10 * 60 * 60:
-             resample_fn = recons.ffmpeg_resample
-             loudness_fn = recons.ffmpeg_loudness
-
-         recons.normalize(obj.input_db)
-         resample_fn(obj.sample_rate)
-         recons = recons[..., : obj.original_length]
-         loudness_fn()
-         recons.audio_data = recons.audio_data.reshape(
-             -1, obj.channels, obj.original_length
-         )
-
-         self.padding = original_padding
-         return recons
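Editor's note: together, DACFile and CodecMixin give a file-level round trip: compress windows the signal, encodes each chunk, and records the metadata that decompress later needs to undo normalization, resampling, and padding. A minimal sketch, assuming a trained DAC checkpoint (the paths and the bare `DAC()` constructor here are illustrative):

# Round trip through the CodecMixin API above; in practice the model would be
# loaded from trained weights rather than constructed fresh.
import dac

model = dac.DAC().to("cpu")
artifact = model.compress("input.wav", win_duration=1.0)  # -> DACFile
artifact.save("input.dac")                 # uint16 codes + metadata via np.save
restored = model.decompress("input.dac")   # -> AudioSignal at the original rate
restored.write("restored.wav")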
 
dac/model/dac.py DELETED
@@ -1,400 +0,0 @@
- import math
- from typing import List
- from typing import Union
-
- import numpy as np
- import torch
- from audiotools import AudioSignal
- from audiotools.ml import BaseModel
- from torch import nn
-
- from .base import CodecMixin
- from dac.nn.layers import Snake1d
- from dac.nn.layers import WNConv1d
- from dac.nn.layers import WNConvTranspose1d
- from dac.nn.quantize import ResidualVectorQuantize
- from .encodec import SConv1d, SConvTranspose1d, SLSTM
-
-
- def init_weights(m):
-     if isinstance(m, nn.Conv1d):
-         nn.init.trunc_normal_(m.weight, std=0.02)
-         nn.init.constant_(m.bias, 0)
-
-
- class ResidualUnit(nn.Module):
-     def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
-         super().__init__()
-         conv1d_type = SConv1d  # if causal else WNConv1d
-         pad = ((7 - 1) * dilation) // 2
-         self.block = nn.Sequential(
-             Snake1d(dim),
-             conv1d_type(dim, dim, kernel_size=7, dilation=dilation, padding=pad, causal=causal, norm='weight_norm'),
-             Snake1d(dim),
-             conv1d_type(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'),
-         )
-
-     def forward(self, x):
-         y = self.block(x)
-         pad = (x.shape[-1] - y.shape[-1]) // 2
-         if pad > 0:
-             x = x[..., pad:-pad]
-         return x + y
-
-
- class EncoderBlock(nn.Module):
-     def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
-         super().__init__()
-         conv1d_type = SConv1d  # if causal else WNConv1d
-         self.block = nn.Sequential(
-             ResidualUnit(dim // 2, dilation=1, causal=causal),
-             ResidualUnit(dim // 2, dilation=3, causal=causal),
-             ResidualUnit(dim // 2, dilation=9, causal=causal),
-             Snake1d(dim // 2),
-             conv1d_type(
-                 dim // 2,
-                 dim,
-                 kernel_size=2 * stride,
-                 stride=stride,
-                 padding=math.ceil(stride / 2),
-                 causal=causal,
-                 norm='weight_norm',
-             ),
-         )
-
-     def forward(self, x):
-         return self.block(x)
-
-
- class Encoder(nn.Module):
-     def __init__(
-         self,
-         d_model: int = 64,
-         strides: list = [2, 4, 8, 8],
-         d_latent: int = 64,
-         causal: bool = False,
-         lstm: int = 2,
-     ):
-         super().__init__()
-         conv1d_type = SConv1d  # if causal else WNConv1d
-         # Create first convolution
-         self.block = [conv1d_type(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
-
-         # Create EncoderBlocks that double channels as they downsample by `stride`
-         for stride in strides:
-             d_model *= 2
-             self.block += [EncoderBlock(d_model, stride=stride, causal=causal)]
-
-         # Add LSTM if needed
-         self.use_lstm = lstm
-         if lstm:
-             self.block += [SLSTM(d_model, lstm)]
-
-         # Create last convolution
-         self.block += [
-             Snake1d(d_model),
-             conv1d_type(d_model, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm'),
-         ]
-
-         # Wrap block into nn.Sequential
-         self.block = nn.Sequential(*self.block)
-         self.enc_dim = d_model
-
-     def forward(self, x):
-         return self.block(x)
-
-     def reset_cache(self):
-         # Recursively find all SConv1d/SLSTM submodules in self.block and call
-         # their reset_cache method
-         def reset_cache(m):
-             if isinstance(m, SConv1d) or isinstance(m, SLSTM):
-                 m.reset_cache()
-                 return
-             for child in m.children():
-                 reset_cache(child)
-
-         reset_cache(self.block)
-
-
- class DecoderBlock(nn.Module):
-     def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False):
-         super().__init__()
-         conv1d_type = SConvTranspose1d  # if causal else WNConvTranspose1d
-         self.block = nn.Sequential(
-             Snake1d(input_dim),
-             conv1d_type(
-                 input_dim,
-                 output_dim,
-                 kernel_size=2 * stride,
-                 stride=stride,
-                 padding=math.ceil(stride / 2),
-                 causal=causal,
-                 norm='weight_norm'
-             ),
-             ResidualUnit(output_dim, dilation=1, causal=causal),
-             ResidualUnit(output_dim, dilation=3, causal=causal),
-             ResidualUnit(output_dim, dilation=9, causal=causal),
-         )
-
-     def forward(self, x):
-         return self.block(x)
-
-
- class Decoder(nn.Module):
-     def __init__(
-         self,
-         input_channel,
-         channels,
-         rates,
-         d_out: int = 1,
-         causal: bool = False,
-         lstm: int = 2,
-     ):
-         super().__init__()
-         conv1d_type = SConv1d  # if causal else WNConv1d
-         # Add first conv layer
-         layers = [conv1d_type(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
-
-         if lstm:
-             layers += [SLSTM(channels, num_layers=lstm)]
-
-         # Add upsampling + MRF blocks
-         for i, stride in enumerate(rates):
-             input_dim = channels // 2**i
-             output_dim = channels // 2 ** (i + 1)
-             layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]
-
-         # Add final conv layer
-         layers += [
-             Snake1d(output_dim),
-             conv1d_type(output_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'),
-             nn.Tanh(),
-         ]
-
-         self.model = nn.Sequential(*layers)
-
-     def forward(self, x):
-         return self.model(x)
-
-
- class DAC(BaseModel, CodecMixin):
-     def __init__(
-         self,
-         encoder_dim: int = 64,
-         encoder_rates: List[int] = [2, 4, 8, 8],
-         latent_dim: int = None,
-         decoder_dim: int = 1536,
-         decoder_rates: List[int] = [8, 8, 4, 2],
-         n_codebooks: int = 9,
-         codebook_size: int = 1024,
-         codebook_dim: Union[int, list] = 8,
-         quantizer_dropout: bool = False,
-         sample_rate: int = 44100,
-         lstm: int = 2,
-         causal: bool = False,
-     ):
-         super().__init__()
-
-         self.encoder_dim = encoder_dim
-         self.encoder_rates = encoder_rates
-         self.decoder_dim = decoder_dim
-         self.decoder_rates = decoder_rates
-         self.sample_rate = sample_rate
-
-         if latent_dim is None:
-             latent_dim = encoder_dim * (2 ** len(encoder_rates))
-
-         self.latent_dim = latent_dim
-
-         self.hop_length = np.prod(encoder_rates)
-         self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm)
-
-         self.n_codebooks = n_codebooks
-         self.codebook_size = codebook_size
-         self.codebook_dim = codebook_dim
-         self.quantizer = ResidualVectorQuantize(
-             input_dim=latent_dim,
-             n_codebooks=n_codebooks,
-             codebook_size=codebook_size,
-             codebook_dim=codebook_dim,
-             quantizer_dropout=quantizer_dropout,
-         )
-
-         self.decoder = Decoder(
-             latent_dim,
-             decoder_dim,
-             decoder_rates,
-             lstm=lstm,
-             causal=causal,
-         )
-         self.sample_rate = sample_rate
-         self.apply(init_weights)
-
-         self.delay = self.get_delay()
-
-     def preprocess(self, audio_data, sample_rate):
-         if sample_rate is None:
-             sample_rate = self.sample_rate
-         assert sample_rate == self.sample_rate
-
-         length = audio_data.shape[-1]
-         right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
-         audio_data = nn.functional.pad(audio_data, (0, right_pad))
-
-         return audio_data
-
-     def encode(
-         self,
-         audio_data: torch.Tensor,
-         n_quantizers: int = None,
-     ):
-         """Encode given audio data and return quantized latent codes
-
-         Parameters
-         ----------
-         audio_data : Tensor[B x 1 x T]
-             Audio data to encode
-         n_quantizers : int, optional
-             Number of quantizers to use, by default None
-             If None, all quantizers are used.
-
-         Returns
-         -------
-         z : Tensor[B x D x T]
-             Quantized continuous representation of input
-         codes : Tensor[B x N x T]
-             Codebook indices for each codebook
-             (quantized discrete representation of input)
-         latents : Tensor[B x N*D x T]
-             Projected latents (continuous representation of input before quantization)
-         commitment_loss : Tensor[1]
-             Commitment loss to train encoder to predict vectors closer to codebook
-             entries
-         codebook_loss : Tensor[1]
-             Codebook loss to update the codebook
-         """
-         z = self.encoder(audio_data)
-         z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
-             z, n_quantizers
-         )
-         return z, codes, latents, commitment_loss, codebook_loss
-
-     def decode(self, z: torch.Tensor):
-         """Decode given latent codes and return audio data
-
-         Parameters
-         ----------
-         z : Tensor[B x D x T]
-             Quantized continuous representation of input
-
-         Returns
-         -------
-         Tensor[B x 1 x length]
-             Decoded audio data.
-         """
-         return self.decoder(z)
-
-     def forward(
-         self,
-         audio_data: torch.Tensor,
-         sample_rate: int = None,
-         n_quantizers: int = None,
-     ):
-         """Model forward pass
-
-         Parameters
-         ----------
-         audio_data : Tensor[B x 1 x T]
-             Audio data to encode
-         sample_rate : int, optional
-             Sample rate of audio data in Hz, by default None
-             If None, defaults to `self.sample_rate`
-         n_quantizers : int, optional
-             Number of quantizers to use, by default None.
-             If None, all quantizers are used.
-
-         Returns
-         -------
-         dict
-             A dictionary with the following keys:
-             "audio" : Tensor[B x 1 x length]
-                 Decoded audio data.
-             "z" : Tensor[B x D x T]
-                 Quantized continuous representation of input
-             "codes" : Tensor[B x N x T]
-                 Codebook indices for each codebook
-                 (quantized discrete representation of input)
-             "latents" : Tensor[B x N*D x T]
-                 Projected latents (continuous representation of input before quantization)
-             "vq/commitment_loss" : Tensor[1]
-                 Commitment loss to train encoder to predict vectors closer to codebook
-                 entries
-             "vq/codebook_loss" : Tensor[1]
-                 Codebook loss to update the codebook
-         """
-         length = audio_data.shape[-1]
-         audio_data = self.preprocess(audio_data, sample_rate)
-         z, codes, latents, commitment_loss, codebook_loss = self.encode(
-             audio_data, n_quantizers
-         )
-
-         x = self.decode(z)
-         return {
-             "audio": x[..., :length],
-             "z": z,
-             "codes": codes,
-             "latents": latents,
-             "vq/commitment_loss": commitment_loss,
-             "vq/codebook_loss": codebook_loss,
-         }
-
-
- if __name__ == "__main__":
-     import numpy as np
-     from functools import partial
-
-     model = DAC().to("cpu")
-
-     for n, m in model.named_modules():
-         o = m.extra_repr()
-         p = sum([np.prod(p.size()) for p in m.parameters()])
-         fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
-         setattr(m, "extra_repr", partial(fn, o=o, p=p))
-     print(model)
-     print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
-
-     length = 88200 * 2
-     x = torch.randn(1, 1, length).to(model.device)
-     x.requires_grad_(True)
-     x.retain_grad()
-
-     # Make a forward pass
-     out = model(x)["audio"]
-     print("Input shape:", x.shape)
-     print("Output shape:", out.shape)
-
-     # Create gradient variable
-     grad = torch.zeros_like(out)
-     grad[:, :, grad.shape[-1] // 2] = 1
-
-     # Make a backward pass
-     out.backward(grad)
-
-     # Check non-zero values
-     gradmap = x.grad.squeeze(0)
-     gradmap = (gradmap != 0).sum(0)  # sum across features
-     rf = (gradmap != 0).sum()
-
-     print(f"Receptive field: {rf.item()}")
-
-     x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
-     model.decompress(model.compress(x, verbose=True), verbose=True)
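Editor's note: beyond the self-test above, the encode/decode split is the core API: preprocess right-pads to a hop-length multiple, encode quantizes the latents, decode reconstructs the waveform. A short sketch (untrained model; shapes follow the docstrings):

# Sketch of the encode/decode path above.
import torch
import dac

model = dac.DAC(sample_rate=44100)
x = torch.randn(1, 1, 44100)                # B x 1 x T mono audio
x = model.preprocess(x, model.sample_rate)  # right-pad T to a hop multiple
z, codes, latents, commitment_loss, codebook_loss = model.encode(x)
audio_hat = model.decode(z)                 # B x 1 x T_padded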
 
dac/model/discriminator.py DELETED
@@ -1,228 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from audiotools import AudioSignal
- from audiotools import ml
- from audiotools import STFTParams
- from einops import rearrange
- from torch.nn.utils import weight_norm
-
-
- def WNConv1d(*args, **kwargs):
-     act = kwargs.pop("act", True)
-     conv = weight_norm(nn.Conv1d(*args, **kwargs))
-     if not act:
-         return conv
-     return nn.Sequential(conv, nn.LeakyReLU(0.1))
-
-
- def WNConv2d(*args, **kwargs):
-     act = kwargs.pop("act", True)
-     conv = weight_norm(nn.Conv2d(*args, **kwargs))
-     if not act:
-         return conv
-     return nn.Sequential(conv, nn.LeakyReLU(0.1))
-
-
- class MPD(nn.Module):
-     def __init__(self, period):
-         super().__init__()
-         self.period = period
-         self.convs = nn.ModuleList(
-             [
-                 WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
-                 WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
-                 WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
-                 WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
-                 WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
-             ]
-         )
-         self.conv_post = WNConv2d(
-             1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
-         )
-
-     def pad_to_period(self, x):
-         t = x.shape[-1]
-         x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
-         return x
-
-     def forward(self, x):
-         fmap = []
-
-         x = self.pad_to_period(x)
-         x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
-
-         for layer in self.convs:
-             x = layer(x)
-             fmap.append(x)
-
-         x = self.conv_post(x)
-         fmap.append(x)
-
-         return fmap
-
-
- class MSD(nn.Module):
-     def __init__(self, rate: int = 1, sample_rate: int = 44100):
-         super().__init__()
-         self.convs = nn.ModuleList(
-             [
-                 WNConv1d(1, 16, 15, 1, padding=7),
-                 WNConv1d(16, 64, 41, 4, groups=4, padding=20),
-                 WNConv1d(64, 256, 41, 4, groups=16, padding=20),
-                 WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
-                 WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
-                 WNConv1d(1024, 1024, 5, 1, padding=2),
-             ]
-         )
-         self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
-         self.sample_rate = sample_rate
-         self.rate = rate
-
-     def forward(self, x):
-         x = AudioSignal(x, self.sample_rate)
-         x.resample(self.sample_rate // self.rate)
-         x = x.audio_data
-
-         fmap = []
-
-         for l in self.convs:
-             x = l(x)
-             fmap.append(x)
-         x = self.conv_post(x)
-         fmap.append(x)
-
-         return fmap
-
-
- BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
-
-
- class MRD(nn.Module):
-     def __init__(
-         self,
-         window_length: int,
-         hop_factor: float = 0.25,
-         sample_rate: int = 44100,
-         bands: list = BANDS,
-     ):
-         """Complex multi-band spectrogram discriminator.
-
-         Parameters
-         ----------
-         window_length : int
-             Window length of STFT.
-         hop_factor : float, optional
-             Hop factor of the STFT (hop_length = window_length * hop_factor),
-             by default 0.25
-         sample_rate : int, optional
-             Sampling rate of audio in Hz, by default 44100
-         bands : list, optional
-             Bands to run discriminator over.
-         """
-         super().__init__()
-
-         self.window_length = window_length
-         self.hop_factor = hop_factor
-         self.sample_rate = sample_rate
-         self.stft_params = STFTParams(
-             window_length=window_length,
-             hop_length=int(window_length * hop_factor),
-             match_stride=True,
-         )
-
-         n_fft = window_length // 2 + 1
-         bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
-         self.bands = bands
-
-         ch = 32
-         convs = lambda: nn.ModuleList(
-             [
-                 WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
-                 WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
-                 WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
-                 WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
-                 WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
-             ]
-         )
-         self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
-         self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
-
-     def spectrogram(self, x):
-         x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
-         x = torch.view_as_real(x.stft())
-         x = rearrange(x, "b 1 f t c -> (b 1) c t f")
-         # Split into bands
-         x_bands = [x[..., b[0] : b[1]] for b in self.bands]
-         return x_bands
-
-     def forward(self, x):
-         x_bands = self.spectrogram(x)
-         fmap = []
-
-         x = []
-         for band, stack in zip(x_bands, self.band_convs):
-             for layer in stack:
-                 band = layer(band)
-                 fmap.append(band)
-             x.append(band)
-
-         x = torch.cat(x, dim=-1)
-         x = self.conv_post(x)
-         fmap.append(x)
-
-         return fmap
-
-
- class Discriminator(nn.Module):
-     def __init__(
-         self,
-         rates: list = [],
-         periods: list = [2, 3, 5, 7, 11],
-         fft_sizes: list = [2048, 1024, 512],
-         sample_rate: int = 44100,
-         bands: list = BANDS,
-     ):
-         """Discriminator that combines multiple discriminators.
-
-         Parameters
-         ----------
-         rates : list, optional
-             downsampling factors to run MSD at (audio is resampled to
-             sample_rate // rate), by default []
-             If empty, MSD is not used.
-         periods : list, optional
-             periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
-         fft_sizes : list, optional
-             Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
-         sample_rate : int, optional
-             Sampling rate of audio in Hz, by default 44100
-         bands : list, optional
-             Bands to run MRD at, by default `BANDS`
-         """
-         super().__init__()
-         discs = []
-         discs += [MPD(p) for p in periods]
-         discs += [MSD(r, sample_rate=sample_rate) for r in rates]
-         discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
-         self.discriminators = nn.ModuleList(discs)
-
-     def preprocess(self, y):
-         # Remove DC offset
-         y = y - y.mean(dim=-1, keepdims=True)
-         # Peak normalize the volume of input audio
-         y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
-         return y
-
-     def forward(self, x):
-         x = self.preprocess(x)
-         fmaps = [d(x) for d in self.discriminators]
-         return fmaps
-
-
- if __name__ == "__main__":
-     disc = Discriminator()
-     x = torch.zeros(1, 1, 44100)
-     results = disc(x)
-     for i, result in enumerate(results):
-         print(f"disc{i}")
-         for i, r in enumerate(result):
-             print(r.shape, r.mean(), r.min(), r.max())
-         print()
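Editor's note: the return structure matters downstream; each sub-discriminator yields a list of feature maps whose last entry is its logit map, and GANLoss in dac/nn/loss.py (further below in this commit) indexes exactly that. A quick shape sketch:

# Each sub-discriminator returns intermediate feature maps plus, as the last
# list entry, the logit map that GANLoss scores.
import torch

disc = Discriminator()
fmaps = disc(torch.zeros(1, 1, 44100))
logits = [f[-1] for f in fmaps]     # one real/fake map per sub-discriminator
features = [f[:-1] for f in fmaps]  # used for the feature-matching loss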
 
dac/model/encodec.py DELETED
@@ -1,320 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- """Convolutional layers wrappers and utilities."""
-
- import math
- import typing as tp
- import warnings
-
- import torch
- from torch import nn
- from torch.nn import functional as F
- from torch.nn.utils import spectral_norm, weight_norm
-
- import einops
-
-
- class ConvLayerNorm(nn.LayerNorm):
-     """
-     Convolution-friendly LayerNorm that moves channels to last dimensions
-     before running the normalization and moves them back to original position right after.
-     """
-     def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
-         super().__init__(normalized_shape, **kwargs)
-
-     def forward(self, x):
-         x = einops.rearrange(x, 'b ... t -> b t ...')
-         x = super().forward(x)
-         x = einops.rearrange(x, 'b t ... -> b ... t')
-         return x
-
-
- CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
-                                  'time_layer_norm', 'layer_norm', 'time_group_norm'])
-
-
- def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
-     assert norm in CONV_NORMALIZATIONS
-     if norm == 'weight_norm':
-         return weight_norm(module)
-     elif norm == 'spectral_norm':
-         return spectral_norm(module)
-     else:
-         # We already checked that `norm` is in CONV_NORMALIZATIONS, so any
-         # other choice doesn't need reparametrization.
-         return module
-
-
- def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
-     """Return the proper normalization module. If causal is True, this will ensure the returned
-     module is causal, or return an error if the normalization doesn't support causal evaluation.
-     """
-     assert norm in CONV_NORMALIZATIONS
-     if norm == 'layer_norm':
-         assert isinstance(module, nn.modules.conv._ConvNd)
-         return ConvLayerNorm(module.out_channels, **norm_kwargs)
-     elif norm == 'time_group_norm':
-         if causal:
-             raise ValueError("GroupNorm doesn't support causal evaluation.")
-         assert isinstance(module, nn.modules.conv._ConvNd)
-         return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
-     else:
-         return nn.Identity()
-
-
- def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
-                                  padding_total: int = 0) -> int:
-     """See `pad_for_conv1d`."""
-     length = x.shape[-1]
-     n_frames = (length - kernel_size + padding_total) / stride + 1
-     ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
-     return ideal_length - length
-
-
- def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
-     """Pad for a convolution to make sure that the last window is full.
-     Extra padding is added at the end. This is required to ensure that we can rebuild
-     an output of the same length, as otherwise, even with padding, some time steps
-     might get removed.
-     For instance, with total padding = 4, kernel size = 4, stride = 2:
-         0 0 1 2 3 4 5 0 0   # (0s are padding)
-         1   2   3           # (output frames of a convolution, last 0 is never used)
-         0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
-         1 2 3 4             # once you removed padding, we are missing one time step!
-     """
-     extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
-     return F.pad(x, (0, extra_padding))
-
-
- def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
-     """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
-     If this is the case, we insert extra 0 padding to the right before the reflection happens.
-     """
-     length = x.shape[-1]
-     padding_left, padding_right = paddings
-     assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
-     if mode == 'reflect':
-         max_pad = max(padding_left, padding_right)
-         extra_pad = 0
-         if length <= max_pad:
-             extra_pad = max_pad - length + 1
-             x = F.pad(x, (0, extra_pad))
-         padded = F.pad(x, paddings, mode, value)
-         end = padded.shape[-1] - extra_pad
-         return padded[..., :end]
-     else:
-         return F.pad(x, paddings, mode, value)
-
-
- def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
-     """Remove padding from x, handling properly zero padding. Only for 1d!"""
-     padding_left, padding_right = paddings
-     assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
-     assert (padding_left + padding_right) <= x.shape[-1]
-     end = x.shape[-1] - padding_right
-     return x[..., padding_left: end]
-
-
- class NormConv1d(nn.Module):
-     """Wrapper around Conv1d and normalization applied to this conv
-     to provide a uniform interface across normalization approaches.
-     """
-     def __init__(self, *args, causal: bool = False, norm: str = 'none',
-                  norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
-         super().__init__()
-         self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
-         self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
-         self.norm_type = norm
-
-     def forward(self, x):
-         x = self.conv(x)
-         x = self.norm(x)
-         return x
-
-
- class NormConv2d(nn.Module):
-     """Wrapper around Conv2d and normalization applied to this conv
-     to provide a uniform interface across normalization approaches.
-     """
-     def __init__(self, *args, norm: str = 'none',
-                  norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
-         super().__init__()
-         self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
-         self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
-         self.norm_type = norm
-
-     def forward(self, x):
-         x = self.conv(x)
-         x = self.norm(x)
-         return x
-
-
- class NormConvTranspose1d(nn.Module):
-     """Wrapper around ConvTranspose1d and normalization applied to this conv
-     to provide a uniform interface across normalization approaches.
-     """
-     def __init__(self, *args, causal: bool = False, norm: str = 'none',
-                  norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
-         super().__init__()
-         self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
-         self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
-         self.norm_type = norm
-
-     def forward(self, x):
-         x = self.convtr(x)
-         x = self.norm(x)
-         return x
-
-
- class NormConvTranspose2d(nn.Module):
-     """Wrapper around ConvTranspose2d and normalization applied to this conv
-     to provide a uniform interface across normalization approaches.
-     """
-     def __init__(self, *args, norm: str = 'none',
-                  norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
-         super().__init__()
-         self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
-         self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
-
-     def forward(self, x):
-         x = self.convtr(x)
-         x = self.norm(x)
-         return x
-
-
- class SConv1d(nn.Module):
-     """Conv1d with some builtin handling of asymmetric or causal padding
-     and normalization.
-     """
-     def __init__(self, in_channels: int, out_channels: int,
-                  kernel_size: int, stride: int = 1, dilation: int = 1,
-                  groups: int = 1, bias: bool = True, causal: bool = False,
-                  norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
-                  pad_mode: str = 'reflect', **kwargs):
-         super().__init__()
-         # warn user on unusual setup between dilation and stride
-         if stride > 1 and dilation > 1:
-             warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
-                           f' (kernel_size={kernel_size}, stride={stride}, dilation={dilation}).')
-         self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
-                                dilation=dilation, groups=groups, bias=bias, causal=causal,
-                                norm=norm, norm_kwargs=norm_kwargs)
-         self.causal = causal
-         self.pad_mode = pad_mode
-
-         self.cache_enabled = False
-
-     def reset_cache(self):
-         """Reset the cache when starting a new stream."""
-         self.cache = None
-         self.cache_enabled = True
-
-     def forward(self, x):
-         B, C, T = x.shape
-         kernel_size = self.conv.conv.kernel_size[0]
-         stride = self.conv.conv.stride[0]
-         dilation = self.conv.conv.dilation[0]
-         kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
-         padding_total = kernel_size - stride
-         extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
-
-         if self.causal:
-             # Left padding for causal
-             if self.cache_enabled and self.cache is not None:
-                 # Concatenate the cache (previous inputs) with the new input for streaming
-                 x = torch.cat([self.cache, x], dim=2)
-             else:
-                 x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
-         else:
-             # Asymmetric padding required for odd strides
-             padding_right = padding_total // 2
-             padding_left = padding_total - padding_right
-             x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
-
-         # Store the most recent input frames for future cache use
-         if self.cache_enabled:
-             if self.cache is None:
-                 # Initialize cache with zeros (at the start of streaming)
-                 self.cache = torch.zeros(B, C, kernel_size - 1, device=x.device)
-             # Update the cache by storing the latest input frames
-             if kernel_size > 1:
-                 self.cache = x[:, :, -kernel_size + 1:].detach()  # Only store the necessary frames
-
-         return self.conv(x)
-
-
- class SConvTranspose1d(nn.Module):
-     """ConvTranspose1d with some builtin handling of asymmetric or causal padding
-     and normalization.
-     """
-     def __init__(self, in_channels: int, out_channels: int,
-                  kernel_size: int, stride: int = 1, causal: bool = False,
-                  norm: str = 'none', trim_right_ratio: float = 1.,
-                  norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
-         super().__init__()
-         self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
-                                           causal=causal, norm=norm, norm_kwargs=norm_kwargs)
-         self.causal = causal
-         self.trim_right_ratio = trim_right_ratio
-         assert self.causal or self.trim_right_ratio == 1., \
-             "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
-         assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
-
-     def forward(self, x):
-         kernel_size = self.convtr.convtr.kernel_size[0]
-         stride = self.convtr.convtr.stride[0]
-         padding_total = kernel_size - stride
-
-         y = self.convtr(x)
-
-         # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
-         # removed at the very end, when keeping only the right length for the output,
-         # as removing it here would require also passing the length at the matching layer
-         # in the encoder.
-         if self.causal:
-             # Trim the padding on the right according to the specified ratio
-             # if trim_right_ratio = 1.0, trim everything from right
-             padding_right = math.ceil(padding_total * self.trim_right_ratio)
-             padding_left = padding_total - padding_right
-             y = unpad1d(y, (padding_left, padding_right))
-         else:
-             # Asymmetric padding required for odd strides
-             padding_right = padding_total // 2
-             padding_left = padding_total - padding_right
-             y = unpad1d(y, (padding_left, padding_right))
-         return y
-
-
- class SLSTM(nn.Module):
-     """
-     LSTM without worrying about the hidden state, nor the layout of the data.
-     Expects input as convolutional layout.
-     """
-     def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
-         super().__init__()
-         self.skip = skip
-         self.lstm = nn.LSTM(dimension, dimension, num_layers)
-         self.hidden = None
-         self.cache_enabled = False
-
-     def forward(self, x):
-         x = x.permute(2, 0, 1)
-         if self.training or not self.cache_enabled:
-             y, _ = self.lstm(x)
-         else:
-             y, self.hidden = self.lstm(x, self.hidden)
-         if self.skip:
-             y = y + x
-         y = y.permute(1, 2, 0)
-         return y
-
-     def reset_cache(self):
-         self.hidden = None
-         self.cache_enabled = True
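Editor's note: the cache path in SConv1d and SLSTM exists for streaming: after reset_cache(), each forward call reuses the previous chunk's tail (or the LSTM hidden state) as left context instead of re-padding from scratch. A minimal sketch of that streaming loop as implemented above (the chunk size is illustrative; in practice it should be a multiple of the layer's stride):

# Streaming sketch for the cached causal path above.
import torch

conv = SConv1d(1, 16, kernel_size=7, causal=True, norm="weight_norm")
conv.eval()
conv.reset_cache()                 # enables caching for a new stream
stream = torch.randn(1, 1, 4096)
with torch.no_grad():
    for chunk in stream.split(512, dim=-1):
        y = conv(chunk)            # left context comes from the cache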
 
dac/nn/__init__.py DELETED
@@ -1,3 +0,0 @@
- from . import layers
- from . import loss
- from . import quantize
 
dac/nn/layers.py DELETED
@@ -1,33 +0,0 @@
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from einops import rearrange
- from torch.nn.utils import weight_norm
-
-
- def WNConv1d(*args, **kwargs):
-     return weight_norm(nn.Conv1d(*args, **kwargs))
-
-
- def WNConvTranspose1d(*args, **kwargs):
-     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
-
-
- # Scripting this brings model speed up 1.4x
- @torch.jit.script
- def snake(x, alpha):
-     shape = x.shape
-     x = x.reshape(shape[0], shape[1], -1)
-     x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
-     x = x.reshape(shape)
-     return x
-
-
- class Snake1d(nn.Module):
-     def __init__(self, channels):
-         super().__init__()
-         self.alpha = nn.Parameter(torch.ones(1, channels, 1))
-
-     def forward(self, x):
-         return snake(x, self.alpha)
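Editor's note: the snake nonlinearity above computes y = x + sin^2(alpha * x) / (alpha + 1e-9) pointwise, with one learnable alpha per channel; the reshape just flattens trailing dims for the TorchScript kernel. A quick shape check:

# Snake1d is pointwise and shape-preserving.
import torch

act = Snake1d(channels=4)
x = torch.randn(2, 4, 100)
y = act(x)
assert y.shape == x.shape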
 
dac/nn/loss.py DELETED
@@ -1,368 +0,0 @@
1
- import typing
2
- from typing import List
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from audiotools import AudioSignal
7
- from audiotools import STFTParams
8
- from torch import nn
9
-
10
-
11
- class L1Loss(nn.L1Loss):
12
- """L1 Loss between AudioSignals. Defaults
13
- to comparing ``audio_data``, but any
14
- attribute of an AudioSignal can be used.
15
-
16
- Parameters
17
- ----------
18
- attribute : str, optional
19
- Attribute of signal to compare, defaults to ``audio_data``.
20
- weight : float, optional
21
- Weight of this loss, defaults to 1.0.
22
-
23
- Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24
- """
25
-
26
- def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27
- self.attribute = attribute
28
- self.weight = weight
29
- super().__init__(**kwargs)
30
-
31
- def forward(self, x: AudioSignal, y: AudioSignal):
32
- """
33
- Parameters
34
- ----------
35
- x : AudioSignal
36
- Estimate AudioSignal
37
- y : AudioSignal
38
- Reference AudioSignal
39
-
40
- Returns
41
- -------
42
- torch.Tensor
43
- L1 loss between AudioSignal attributes.
44
- """
45
- if isinstance(x, AudioSignal):
46
- x = getattr(x, self.attribute)
47
- y = getattr(y, self.attribute)
48
- return super().forward(x, y)
49
-
50
-
51
- class SISDRLoss(nn.Module):
52
- """
53
- Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54
- of estimated and reference audio signals or aligned features.
55
-
56
- Parameters
57
- ----------
58
- scaling : int, optional
59
- Whether to use scale-invariant (True) or
60
- signal-to-noise ratio (False), by default True
61
- reduction : str, optional
62
- How to reduce across the batch (either 'mean',
63
- 'sum', or none).], by default ' mean'
64
- zero_mean : int, optional
65
- Zero mean the references and estimates before
66
- computing the loss, by default True
67
- clip_min : int, optional
68
- The minimum possible loss value. Helps network
69
- to not focus on making already good examples better, by default None
70
- weight : float, optional
71
- Weight of this loss, defaults to 1.0.
72
-
73
- Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74
- """
75
-
76
- def __init__(
77
- self,
78
- scaling: int = True,
79
- reduction: str = "mean",
80
- zero_mean: int = True,
81
- clip_min: int = None,
82
- weight: float = 1.0,
83
- ):
84
- self.scaling = scaling
85
- self.reduction = reduction
86
- self.zero_mean = zero_mean
87
- self.clip_min = clip_min
88
- self.weight = weight
89
- super().__init__()
90
-
91
- def forward(self, x: AudioSignal, y: AudioSignal):
92
- eps = 1e-8
93
- # nb, nc, nt
94
- if isinstance(x, AudioSignal):
95
- references = x.audio_data
96
- estimates = y.audio_data
97
- else:
98
- references = x
99
- estimates = y
100
-
101
- nb = references.shape[0]
102
- references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103
- estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104
-
105
- # samples now on axis 1
106
- if self.zero_mean:
107
- mean_reference = references.mean(dim=1, keepdim=True)
108
- mean_estimate = estimates.mean(dim=1, keepdim=True)
109
- else:
110
- mean_reference = 0
111
- mean_estimate = 0
112
-
113
- _references = references - mean_reference
114
- _estimates = estimates - mean_estimate
115
-
116
- references_projection = (_references**2).sum(dim=-2) + eps
117
- references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118
-
119
- scale = (
120
- (references_on_estimates / references_projection).unsqueeze(1)
121
- if self.scaling
122
- else 1
123
- )
124
-
125
- e_true = scale * _references
126
- e_res = _estimates - e_true
127
-
128
- signal = (e_true**2).sum(dim=1)
129
- noise = (e_res**2).sum(dim=1)
130
- sdr = -10 * torch.log10(signal / noise + eps)
131
-
132
- if self.clip_min is not None:
133
- sdr = torch.clamp(sdr, min=self.clip_min)
134
-
135
- if self.reduction == "mean":
136
- sdr = sdr.mean()
137
- elif self.reduction == "sum":
138
- sdr = sdr.sum()
139
- return sdr
140
-
141
-
142
- class MultiScaleSTFTLoss(nn.Module):
143
- """Computes the multi-scale STFT loss from [1].
144
-
145
- Parameters
146
- ----------
147
- window_lengths : List[int], optional
148
- Length of each window of each STFT, by default [2048, 512]
149
- loss_fn : typing.Callable, optional
150
- How to compare each loss, by default nn.L1Loss()
151
- clamp_eps : float, optional
152
- Clamp on the log magnitude, below, by default 1e-5
153
- mag_weight : float, optional
154
- Weight of raw magnitude portion of loss, by default 1.0
155
- log_weight : float, optional
156
- Weight of log magnitude portion of loss, by default 1.0
157
- pow : float, optional
158
- Power to raise magnitude to before taking log, by default 2.0
159
- weight : float, optional
160
- Weight of this loss, by default 1.0
161
- match_stride : bool, optional
162
- Whether to match the stride of convolutional layers, by default False
163
-
164
- References
165
- ----------
166
-
167
- 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168
- "DDSP: Differentiable Digital Signal Processing."
169
- International Conference on Learning Representations. 2019.
170
-
171
- Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172
- """
173
-
174
- def __init__(
175
- self,
176
- window_lengths: List[int] = [2048, 512],
177
- loss_fn: typing.Callable = nn.L1Loss(),
178
- clamp_eps: float = 1e-5,
179
- mag_weight: float = 1.0,
180
- log_weight: float = 1.0,
181
- pow: float = 2.0,
182
- weight: float = 1.0,
183
- match_stride: bool = False,
184
- window_type: str = None,
185
- ):
186
- super().__init__()
187
- self.stft_params = [
188
- STFTParams(
189
- window_length=w,
190
- hop_length=w // 4,
191
- match_stride=match_stride,
192
- window_type=window_type,
193
- )
194
- for w in window_lengths
195
- ]
196
- self.loss_fn = loss_fn
197
- self.log_weight = log_weight
198
- self.mag_weight = mag_weight
199
- self.clamp_eps = clamp_eps
200
- self.weight = weight
201
- self.pow = pow
202
-
203
- def forward(self, x: AudioSignal, y: AudioSignal):
204
- """Computes multi-scale STFT between an estimate and a reference
205
- signal.
206
-
207
- Parameters
208
- ----------
209
- x : AudioSignal
210
- Estimate signal
211
- y : AudioSignal
212
- Reference signal
213
-
214
- Returns
215
- -------
216
- torch.Tensor
217
- Multi-scale STFT loss.
218
- """
219
- loss = 0.0
220
- for s in self.stft_params:
221
- x.stft(s.window_length, s.hop_length, s.window_type)
222
- y.stft(s.window_length, s.hop_length, s.window_type)
223
- loss += self.log_weight * self.loss_fn(
224
- x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225
- y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226
- )
227
- loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228
- return loss
229
-
230
-
231
- class MelSpectrogramLoss(nn.Module):
232
- """Compute distance between mel spectrograms. Can be used
233
- in a multi-scale way.
234
-
235
- Parameters
236
- ----------
237
- n_mels : List[int]
238
- Number of mels per STFT, by default [150, 80],
239
- window_lengths : List[int], optional
240
- Length of each window of each STFT, by default [2048, 512]
241
- loss_fn : typing.Callable, optional
242
- How to compare each loss, by default nn.L1Loss()
243
- clamp_eps : float, optional
244
- Clamp on the log magnitude, below, by default 1e-5
245
- mag_weight : float, optional
246
- Weight of raw magnitude portion of loss, by default 1.0
247
- log_weight : float, optional
248
- Weight of log magnitude portion of loss, by default 1.0
249
- pow : float, optional
250
- Power to raise magnitude to before taking log, by default 2.0
251
- weight : float, optional
252
- Weight of this loss, by default 1.0
253
- match_stride : bool, optional
254
- Whether to match the stride of convolutional layers, by default False
255
-
256
- Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257
- """
258
-
259
- def __init__(
260
- self,
261
- n_mels: List[int] = [150, 80],
262
- window_lengths: List[int] = [2048, 512],
263
- loss_fn: typing.Callable = nn.L1Loss(),
264
- clamp_eps: float = 1e-5,
265
- mag_weight: float = 1.0,
266
- log_weight: float = 1.0,
267
- pow: float = 2.0,
268
- weight: float = 1.0,
269
- match_stride: bool = False,
270
- mel_fmin: List[float] = [0.0, 0.0],
271
- mel_fmax: List[float] = [None, None],
272
- window_type: str = None,
273
- ):
274
- super().__init__()
275
- self.stft_params = [
276
- STFTParams(
277
- window_length=w,
278
- hop_length=w // 4,
279
- match_stride=match_stride,
280
- window_type=window_type,
281
- )
282
- for w in window_lengths
283
- ]
284
- self.n_mels = n_mels
285
- self.loss_fn = loss_fn
286
- self.clamp_eps = clamp_eps
287
- self.log_weight = log_weight
288
- self.mag_weight = mag_weight
289
- self.weight = weight
290
- self.mel_fmin = mel_fmin
291
- self.mel_fmax = mel_fmax
292
- self.pow = pow
293
-
294
- def forward(self, x: AudioSignal, y: AudioSignal):
295
- """Computes mel loss between an estimate and a reference
296
- signal.
297
-
298
- Parameters
299
- ----------
300
- x : AudioSignal
301
- Estimate signal
302
- y : AudioSignal
303
- Reference signal
304
-
305
- Returns
306
- -------
307
- torch.Tensor
308
- Mel loss.
309
- """
310
- loss = 0.0
311
- for n_mels, fmin, fmax, s in zip(
312
- self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313
- ):
314
- kwargs = {
315
- "window_length": s.window_length,
316
- "hop_length": s.hop_length,
317
- "window_type": s.window_type,
318
- }
319
- x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320
- y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321
-
322
- loss += self.log_weight * self.loss_fn(
323
- x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324
- y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325
- )
326
- loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327
- return loss
328
-
329
-
330
- class GANLoss(nn.Module):
331
- """
332
- Computes a discriminator loss, given a discriminator on
333
- generated waveforms/spectrograms compared to ground truth
334
- waveforms/spectrograms. Computes the loss for both the
335
- discriminator and the generator in separate functions.
336
- """
337
-
338
- def __init__(self, discriminator):
339
- super().__init__()
340
- self.discriminator = discriminator
341
-
342
- def forward(self, fake, real):
343
- d_fake = self.discriminator(fake.audio_data)
344
- d_real = self.discriminator(real.audio_data)
345
- return d_fake, d_real
346
-
347
- def discriminator_loss(self, fake, real):
348
- d_fake, d_real = self.forward(fake.clone().detach(), real)
349
-
350
- loss_d = 0
351
- for x_fake, x_real in zip(d_fake, d_real):
352
- loss_d += torch.mean(x_fake[-1] ** 2)
353
- loss_d += torch.mean((1 - x_real[-1]) ** 2)
354
- return loss_d
355
-
356
- def generator_loss(self, fake, real):
357
- d_fake, d_real = self.forward(fake, real)
358
-
359
- loss_g = 0
360
- for x_fake in d_fake:
361
- loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362
-
363
- loss_feature = 0
364
-
365
- for i in range(len(d_fake)):
366
- for j in range(len(d_fake[i]) - 1):
367
- loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368
- return loss_g, loss_feature
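
Note: a minimal sketch of how the GANLoss above is typically driven in a training step. ToyDisc and Sig below are hypothetical stand-ins for a real multi-scale discriminator and for audiotools' AudioSignal (only the attributes GANLoss touches are modeled); this is illustrative, not this repo's training code.

import torch
import torch.nn as nn

class ToyDisc(nn.Module):
    # Hypothetical sub-discriminator returning intermediate feature maps,
    # final logits last; GANLoss expects a list of such per-discriminator lists.
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Conv1d(1, 8, 15, stride=2), nn.Conv1d(8, 1, 15, stride=2)]
        )

    def forward(self, x):
        feats = []
        for conv in self.convs:
            x = conv(x)
            feats.append(x)
        return [feats]  # outer list: one entry per sub-discriminator

class Sig:
    # Hypothetical stand-in exposing only what GANLoss uses from AudioSignal.
    def __init__(self, audio_data):
        self.audio_data = audio_data
    def clone(self):
        return Sig(self.audio_data.clone())
    def detach(self):
        return Sig(self.audio_data.detach())

gan = GANLoss(ToyDisc())  # GANLoss as defined in the diff above
real = Sig(torch.randn(2, 1, 4410))
fake = Sig(torch.randn(2, 1, 4410, requires_grad=True))

loss_d = gan.discriminator_loss(fake, real)         # drives the discriminator update
loss_g, loss_feat = gan.generator_loss(fake, real)  # drives the generator update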
 
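Note: a rough standalone approximation of one scale of the multi-scale mel loss above, using torchaudio instead of audiotools (window length, hop, mel count, clamp, power, and weights mirror the defaults documented above, but the exact audiotools spectrogram settings may differ, so treat this as a sketch):

import torch
import torch.nn as nn
import torchaudio

def mel_loss_one_scale(x, y, sr=44100, n_fft=2048, n_mels=80,
                       clamp_eps=1e-5, power=2.0, mag_weight=1.0, log_weight=1.0):
    # L1 on clamped, squared, log10 mels plus L1 on raw magnitude mels.
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=n_fft, hop_length=n_fft // 4, n_mels=n_mels, power=1.0
    )
    x_mels, y_mels = mel(x), mel(y)
    l1 = nn.L1Loss()
    loss = log_weight * l1(
        x_mels.clamp(clamp_eps).pow(power).log10(),
        y_mels.clamp(clamp_eps).pow(power).log10(),
    )
    loss = loss + mag_weight * l1(x_mels, y_mels)
    return loss

x, y = torch.randn(1, 44100), torch.randn(1, 44100)
print(mel_loss_one_scale(x, y))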
dac/nn/quantize.py DELETED
@@ -1,339 +0,0 @@
1
- from typing import Union
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from einops import rearrange
8
- from torch.nn.utils import weight_norm
9
-
10
- from dac.nn.layers import WNConv1d
11
-
12
- class VectorQuantizeLegacy(nn.Module):
13
- """
14
- Implementation of VQ similar to Karpathy's repo:
15
- https://github.com/karpathy/deep-vector-quantization
16
- with the input/output projections removed
17
- """
18
-
19
- def __init__(self, input_dim: int, codebook_size: int):
20
- super().__init__()
21
- self.codebook_size = codebook_size
22
- self.codebook = nn.Embedding(codebook_size, input_dim)
23
-
24
- def forward(self, z, z_mask=None):
25
- """Quantizes the input tensor using a fixed codebook and returns
26
- the corresponding codebook vectors
27
-
28
- Parameters
29
- ----------
30
- z : Tensor[B x D x T]
31
-
32
- Returns
33
- -------
34
- Tensor[B x D x T]
35
- Quantized continuous representation of input
36
- Tensor[B x T]
37
- Codebook indices (quantized discrete representation of input)
38
- Tensor[B x D x T]
39
- Latents (continuous representation of input before quantization)
40
- Tensor[1]
41
- Commitment loss to train encoder to predict vectors closer to codebook
42
- entries
43
- Tensor[1]
44
- Codebook loss to update the codebook
45
- """
46
-
47
- z_e = z
48
- z_q, indices = self.decode_latents(z)
49
-
50
- if z_mask is not None:
51
- commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
52
- codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
53
- else:
54
- commitment_loss = F.mse_loss(z_e, z_q.detach())
55
- codebook_loss = F.mse_loss(z_q, z_e.detach())
56
- z_q = (
57
- z_e + (z_q - z_e).detach()
58
- ) # noop in forward pass, straight-through gradient estimator in backward pass
59
-
60
- return z_q, indices, z_e, commitment_loss, codebook_loss
61
-
62
- def embed_code(self, embed_id):
63
- return F.embedding(embed_id, self.codebook.weight)
64
-
65
- def decode_code(self, embed_id):
66
- return self.embed_code(embed_id).transpose(1, 2)
67
-
68
- def decode_latents(self, latents):
69
- encodings = rearrange(latents, "b d t -> (b t) d")
70
- codebook = self.codebook.weight # codebook: (N x D)
71
-
72
- # L2 normalize encodings and codebook (ViT-VQGAN)
73
- encodings = F.normalize(encodings)
74
- codebook = F.normalize(codebook)
75
-
76
- # Compute euclidean distance with codebook
77
- dist = (
78
- encodings.pow(2).sum(1, keepdim=True)
79
- - 2 * encodings @ codebook.t()
80
- + codebook.pow(2).sum(1, keepdim=True).t()
81
- )
82
- indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
83
- z_q = self.decode_code(indices)
84
- return z_q, indices
85
-
86
- class VectorQuantize(nn.Module):
87
- """
88
- Implementation of VQ similar to Karpathy's repo:
89
- https://github.com/karpathy/deep-vector-quantization
90
- Additionally uses following tricks from Improved VQGAN
91
- (https://arxiv.org/pdf/2110.04627.pdf):
92
- 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
93
- for improved codebook usage
94
- 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
95
- improves training stability
96
- """
97
-
98
- def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
99
- super().__init__()
100
- self.codebook_size = codebook_size
101
- self.codebook_dim = codebook_dim
102
-
103
- self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
104
- self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
105
- self.codebook = nn.Embedding(codebook_size, codebook_dim)
106
-
107
- def forward(self, z, z_mask=None):
108
- """Quantizes the input tensor using a fixed codebook and returns
109
- the corresponding codebook vectors
110
-
111
- Parameters
112
- ----------
113
- z : Tensor[B x D x T]
114
-
115
- Returns
116
- -------
117
- Tensor[B x D x T]
118
- Quantized continuous representation of input
119
- Tensor[1]
120
- Commitment loss to train encoder to predict vectors closer to codebook
121
- entries
122
- Tensor[1]
123
- Codebook loss to update the codebook
124
- Tensor[B x T]
125
- Codebook indices (quantized discrete representation of input)
126
- Tensor[B x D x T]
127
- Projected latents (continuous representation of input before quantization)
128
- """
129
-
130
- # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
131
- z_e = self.in_proj(z) # z_e : (B x D x T)
132
- z_q, indices = self.decode_latents(z_e)
133
-
134
- if z_mask is not None:
135
- commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
136
- codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
137
- else:
138
- commitment_loss = F.mse_loss(z_e, z_q.detach())
139
- codebook_loss = F.mse_loss(z_q, z_e.detach())
140
-
141
- z_q = (
142
- z_e + (z_q - z_e).detach()
143
- ) # noop in forward pass, straight-through gradient estimator in backward pass
144
-
145
- z_q = self.out_proj(z_q)
146
-
147
- return z_q, commitment_loss, codebook_loss, indices, z_e
148
-
149
- def embed_code(self, embed_id):
150
- return F.embedding(embed_id, self.codebook.weight)
151
-
152
- def decode_code(self, embed_id):
153
- return self.embed_code(embed_id).transpose(1, 2)
154
-
155
- def decode_latents(self, latents):
156
- encodings = rearrange(latents, "b d t -> (b t) d")
157
- codebook = self.codebook.weight # codebook: (N x D)
158
-
159
- # L2 normalize encodings and codebook (ViT-VQGAN)
160
- encodings = F.normalize(encodings)
161
- codebook = F.normalize(codebook)
162
-
163
- # Compute euclidean distance with codebook
164
- dist = (
165
- encodings.pow(2).sum(1, keepdim=True)
166
- - 2 * encodings @ codebook.t()
167
- + codebook.pow(2).sum(1, keepdim=True).t()
168
- )
169
- indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
170
- z_q = self.decode_code(indices)
171
- return z_q, indices
172
-
173
-
174
- class ResidualVectorQuantize(nn.Module):
175
- """
176
- Introduced in SoundStream: An End-to-End Neural Audio Codec
177
- https://arxiv.org/abs/2107.03312
178
- """
179
-
180
- def __init__(
181
- self,
182
- input_dim: int = 512,
183
- n_codebooks: int = 9,
184
- codebook_size: int = 1024,
185
- codebook_dim: Union[int, list] = 8,
186
- quantizer_dropout: float = 0.0,
187
- ):
188
- super().__init__()
189
- if isinstance(codebook_dim, int):
190
- codebook_dim = [codebook_dim for _ in range(n_codebooks)]
191
-
192
- self.n_codebooks = n_codebooks
193
- self.codebook_dim = codebook_dim
194
- self.codebook_size = codebook_size
195
-
196
- self.quantizers = nn.ModuleList(
197
- [
198
- VectorQuantize(input_dim, codebook_size, codebook_dim[i])
199
- for i in range(n_codebooks)
200
- ]
201
- )
202
- self.quantizer_dropout = quantizer_dropout
203
-
204
- def forward(self, z, n_quantizers: int = None):
205
- """Quantizes the input tensor using a fixed set of `n` codebooks and returns
206
- the corresponding codebook vectors
207
- Parameters
208
- ----------
209
- z : Tensor[B x D x T]
210
- n_quantizers : int, optional
211
- No. of quantizers to use
212
- (n_quantizers < self.n_codebooks, e.g. for quantizer dropout)
213
- Note: if `self.quantizer_dropout` is True, this argument is ignored
214
- when in training mode, and a random number of quantizers is used.
215
- Returns
216
- -------
217
- tuple
218
- A tuple with the following entries, in order:
219
-
220
- "z" : Tensor[B x D x T]
221
- Quantized continuous representation of input
222
- "codes" : Tensor[B x N x T]
223
- Codebook indices for each codebook
224
- (quantized discrete representation of input)
225
- "latents" : Tensor[B x N*D x T]
226
- Projected latents (continuous representation of input before quantization)
227
- "vq/commitment_loss" : Tensor[1]
228
- Commitment loss to train encoder to predict vectors closer to codebook
229
- entries
230
- "vq/codebook_loss" : Tensor[1]
231
- Codebook loss to update the codebook
232
- """
233
- z_q = 0
234
- residual = z
235
- commitment_loss = 0
236
- codebook_loss = 0
237
-
238
- codebook_indices = []
239
- latents = []
240
-
241
- if n_quantizers is None:
242
- n_quantizers = self.n_codebooks
243
- if self.training:
244
- n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
245
- dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
246
- n_dropout = int(z.shape[0] * self.quantizer_dropout)
247
- n_quantizers[:n_dropout] = dropout[:n_dropout]
248
- n_quantizers = n_quantizers.to(z.device)
249
-
250
- for i, quantizer in enumerate(self.quantizers):
251
- if self.training is False and i >= n_quantizers:
252
- break
253
-
254
- z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
255
- residual
256
- )
257
-
258
- # Create mask to apply quantizer dropout
259
- mask = (
260
- torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
261
- )
262
- z_q = z_q + z_q_i * mask[:, None, None]
263
- residual = residual - z_q_i
264
-
265
- # Sum losses
266
- commitment_loss += (commitment_loss_i * mask).mean()
267
- codebook_loss += (codebook_loss_i * mask).mean()
268
-
269
- codebook_indices.append(indices_i)
270
- latents.append(z_e_i)
271
-
272
- codes = torch.stack(codebook_indices, dim=1)
273
- latents = torch.cat(latents, dim=1)
274
-
275
- return z_q, codes, latents, commitment_loss, codebook_loss
276
-
277
- def from_codes(self, codes: torch.Tensor):
278
- """Given the quantized codes, reconstruct the continuous representation
279
- Parameters
280
- ----------
281
- codes : Tensor[B x N x T]
282
- Quantized discrete representation of input
283
- Returns
284
- -------
285
- Tensor[B x D x T]
286
- Quantized continuous representation of input
287
- """
288
- z_q = 0.0
289
- z_p = []
290
- n_codebooks = codes.shape[1]
291
- for i in range(n_codebooks):
292
- z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
293
- z_p.append(z_p_i)
294
-
295
- z_q_i = self.quantizers[i].out_proj(z_p_i)
296
- z_q = z_q + z_q_i
297
- return z_q, torch.cat(z_p, dim=1), codes
298
-
299
- def from_latents(self, latents: torch.Tensor):
300
- """Given the unquantized latents, reconstruct the
301
- continuous representation after quantization.
302
-
303
- Parameters
304
- ----------
305
- latents : Tensor[B x N x T]
306
- Continuous representation of input after projection
307
-
308
- Returns
309
- -------
310
- Tensor[B x D x T]
311
- Quantized representation of the full projected space
312
- Tensor[B x D x T]
313
- Quantized representation of latent space
314
- """
315
- z_q = 0
316
- z_p = []
317
- codes = []
318
- dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
319
-
320
- n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
321
- 0
322
- ]
323
- for i in range(n_codebooks):
324
- j, k = dims[i], dims[i + 1]
325
- z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
326
- z_p.append(z_p_i)
327
- codes.append(codes_i)
328
-
329
- z_q_i = self.quantizers[i].out_proj(z_p_i)
330
- z_q = z_q + z_q_i
331
-
332
- return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
333
-
334
-
335
- if __name__ == "__main__":
336
- rvq = ResidualVectorQuantize(quantizer_dropout=True)
337
- x = torch.randn(16, 512, 80)
338
- z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
339
- print(latents.shape)
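
Note: the z_e + (z_q - z_e).detach() pattern used by every quantizer above is the straight-through estimator; the quantized value is used in the forward pass while gradients flow to the encoder output unchanged. A tiny self-contained illustration (torch.round stands in for the non-differentiable codebook lookup):

import torch

z_e = torch.randn(4, 8, requires_grad=True)  # encoder output
z_q = torch.round(z_e)                       # stand-in for the codebook lookup
z_st = z_e + (z_q - z_e).detach()            # forward: equals z_q; backward: identity
z_st.sum().backward()
print(torch.allclose(z_st, z_q), z_e.grad.unique())  # True tensor([1.])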
 
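Note: a sketch of a quantize/reconstruct round trip through ResidualVectorQuantize, assuming the module is importable from its pre-deletion path; shapes follow the defaults above.

import torch
from dac.nn.quantize import ResidualVectorQuantize  # pre-deletion path

rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024)
rvq.eval()  # avoid the training-time quantizer-dropout branch

x = torch.randn(2, 512, 80)                     # (B, D, T) continuous latents
z_q, codes, latents, commit, codebook = rvq(x)  # forward returns a tuple
print(codes.shape)                              # (2, 9, 80): one index per codebook per frame

# Reconstruct the continuous representation from the discrete codes alone.
z_hat, z_p, _ = rvq.from_codes(codes)
print(torch.allclose(z_hat, z_q, atol=1e-5))    # True in eval mode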
dac/utils/__init__.py DELETED
@@ -1,123 +0,0 @@
1
- from pathlib import Path
2
-
3
- import argbind
4
- from audiotools import ml
5
-
6
- import dac
7
-
8
- DAC = dac.model.DAC
9
- Accelerator = ml.Accelerator
10
-
11
- __MODEL_LATEST_TAGS__ = {
12
- ("44khz", "8kbps"): "0.0.1",
13
- ("24khz", "8kbps"): "0.0.4",
14
- ("16khz", "8kbps"): "0.0.5",
15
- ("44khz", "16kbps"): "1.0.0",
16
- }
17
-
18
- __MODEL_URLS__ = {
19
- (
20
- "44khz",
21
- "0.0.1",
22
- "8kbps",
23
- ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
24
- (
25
- "24khz",
26
- "0.0.4",
27
- "8kbps",
28
- ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
29
- (
30
- "16khz",
31
- "0.0.5",
32
- "8kbps",
33
- ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
34
- (
35
- "44khz",
36
- "1.0.0",
37
- "16kbps",
38
- ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
39
- }
40
-
41
-
42
- @argbind.bind(group="download", positional=True, without_prefix=True)
43
- def download(
44
- model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
45
- ):
46
- """
47
- Downloads the model weights from the release URL if a cached copy is not found locally.
48
-
49
- Parameters
50
- ----------
51
- model_type : str
52
- The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
53
- model_bitrate: str
54
- Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
55
- Only 44khz model supports 16kbps.
56
- tag : str
57
- The tag of the model to download. Defaults to "latest".
58
-
59
- Returns
60
- -------
61
- Path
62
- Directory path required to load model via audiotools.
63
- """
64
- model_type = model_type.lower()
65
- tag = tag.lower()
66
-
67
- assert model_type in [
68
- "44khz",
69
- "24khz",
70
- "16khz",
71
- ], "model_type must be one of '44khz', '24khz', or '16khz'"
72
-
73
- assert model_bitrate in [
74
- "8kbps",
75
- "16kbps",
76
- ], "model_bitrate must be one of '8kbps', or '16kbps'"
77
-
78
- if tag == "latest":
79
- tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
80
-
81
- download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
82
-
83
- if download_link is None:
84
- raise ValueError(
85
- f"Could not find model with type {model_type}, bitrate {model_bitrate}, and tag {tag}"
86
- )
87
-
88
- local_path = (
89
- Path.home()
90
- / ".cache"
91
- / "descript"
92
- / "dac"
93
- / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
94
- )
95
- if not local_path.exists():
96
- local_path.parent.mkdir(parents=True, exist_ok=True)
97
-
98
- # Download the model
99
- import requests
100
-
101
- response = requests.get(download_link)
102
-
103
- if response.status_code != 200:
104
- raise ValueError(
105
- f"Could not download model. Received response code {response.status_code}"
106
- )
107
- local_path.write_bytes(response.content)
108
-
109
- return local_path
110
-
111
-
112
- def load_model(
113
- model_type: str = "44khz",
114
- model_bitrate: str = "8kbps",
115
- tag: str = "latest",
116
- load_path: str = None,
117
- ):
118
- if not load_path:
119
- load_path = download(
120
- model_type=model_type, model_bitrate=model_bitrate, tag=tag
121
- )
122
- generator = DAC.load(load_path)
123
- return generator
 
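Note: typical programmatic use of the helpers above (import path as before this deletion; the first call downloads weights to ~/.cache/descript/dac):

import torch
from dac.utils import load_model  # pre-deletion path

model = load_model(model_type="44khz", model_bitrate="8kbps", tag="latest")
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()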
dac/utils/decode.py DELETED
@@ -1,95 +0,0 @@
1
- import warnings
2
- from pathlib import Path
3
-
4
- import argbind
5
- import numpy as np
6
- import torch
7
- from audiotools import AudioSignal
8
- from tqdm import tqdm
9
-
10
- from dac import DACFile
11
- from dac.utils import load_model
12
-
13
- warnings.filterwarnings("ignore", category=UserWarning)
14
-
15
-
16
- @argbind.bind(group="decode", positional=True, without_prefix=True)
17
- @torch.inference_mode()
18
- @torch.no_grad()
19
- def decode(
20
- input: str,
21
- output: str = "",
22
- weights_path: str = "",
23
- model_tag: str = "latest",
24
- model_bitrate: str = "8kbps",
25
- device: str = "cuda",
26
- model_type: str = "44khz",
27
- verbose: bool = False,
28
- ):
29
- """Decode audio from codes.
30
-
31
- Parameters
32
- ----------
33
- input : str
34
- Path to input directory or file
35
- output : str, optional
36
- Path to output directory, by default "".
37
- If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38
- weights_path : str, optional
39
- Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
40
- model_tag and model_type.
41
- model_tag : str, optional
42
- Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43
- model_bitrate: str
44
- Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45
- device : str, optional
46
- Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
47
- model_type : str, optional
48
- The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
49
- """
50
- generator = load_model(
51
- model_type=model_type,
52
- model_bitrate=model_bitrate,
53
- tag=model_tag,
54
- load_path=weights_path,
55
- )
56
- generator.to(device)
57
- generator.eval()
58
-
59
- # Find all .dac files in input directory
60
- _input = Path(input)
61
- input_files = list(_input.glob("**/*.dac"))
62
-
63
- # If input is a .dac file, add it to the list
64
- if _input.suffix == ".dac":
65
- input_files.append(_input)
66
-
67
- # Create output directory
68
- output = Path(output)
69
- output.mkdir(parents=True, exist_ok=True)
70
-
71
- for i in tqdm(range(len(input_files)), desc="Decoding files"):
72
- # Load file
73
- artifact = DACFile.load(input_files[i])
74
-
75
- # Reconstruct audio from codes
76
- recons = generator.decompress(artifact, verbose=verbose)
77
-
78
- # Compute output path
79
- relative_path = input_files[i].relative_to(input)
80
- output_dir = output / relative_path.parent
81
- if not relative_path.name:
82
- output_dir = output
83
- relative_path = input_files[i]
84
- output_name = relative_path.with_suffix(".wav").name
85
- output_path = output_dir / output_name
86
- output_path.parent.mkdir(parents=True, exist_ok=True)
87
-
88
- # Write to file
89
- recons.write(output_path)
90
-
91
-
92
- if __name__ == "__main__":
93
- args = argbind.parse_args()
94
- with argbind.scope(args):
95
- decode()
 
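Note: decode() is argbind-bound for CLI use but can also be called directly from Python; a sketch with illustrative directory names ("codes" and "wavs" are hypothetical):

from dac.utils.decode import decode  # pre-deletion path

# Decodes every .dac file under codes/ to a .wav under wavs/, mirroring the tree.
decode(input="codes", output="wavs", model_type="44khz", device="cpu")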
dac/utils/encode.py DELETED
@@ -1,94 +0,0 @@
1
- import math
2
- import warnings
3
- from pathlib import Path
4
-
5
- import argbind
6
- import numpy as np
7
- import torch
8
- from audiotools import AudioSignal
9
- from audiotools.core import util
10
- from tqdm import tqdm
11
-
12
- from dac.utils import load_model
13
-
14
- warnings.filterwarnings("ignore", category=UserWarning)
15
-
16
-
17
- @argbind.bind(group="encode", positional=True, without_prefix=True)
18
- @torch.inference_mode()
19
- @torch.no_grad()
20
- def encode(
21
- input: str,
22
- output: str = "",
23
- weights_path: str = "",
24
- model_tag: str = "latest",
25
- model_bitrate: str = "8kbps",
26
- n_quantizers: int = None,
27
- device: str = "cuda",
28
- model_type: str = "44khz",
29
- win_duration: float = 5.0,
30
- verbose: bool = False,
31
- ):
32
- """Encode audio files in input path to .dac format.
33
-
34
- Parameters
35
- ----------
36
- input : str
37
- Path to input audio file or directory
38
- output : str, optional
39
- Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
40
- weights_path : str, optional
41
- Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
42
- model_tag and model_type.
43
- model_tag : str, optional
44
- Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
45
- model_bitrate: str
46
- Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
47
- n_quantizers : int, optional
48
- Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
49
- device : str, optional
50
- Device to use, by default "cuda"
51
- model_type : str, optional
52
- The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
53
- """
54
- generator = load_model(
55
- model_type=model_type,
56
- model_bitrate=model_bitrate,
57
- tag=model_tag,
58
- load_path=weights_path,
59
- )
60
- generator.to(device)
61
- generator.eval()
62
- kwargs = {"n_quantizers": n_quantizers}
63
-
64
- # Find all audio files in input path
65
- input = Path(input)
66
- audio_files = util.find_audio(input)
67
-
68
- output = Path(output)
69
- output.mkdir(parents=True, exist_ok=True)
70
-
71
- for i in tqdm(range(len(audio_files)), desc="Encoding files"):
72
- # Load file
73
- signal = AudioSignal(audio_files[i])
74
-
75
- # Encode audio to .dac format
76
- artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
77
-
78
- # Compute output path
79
- relative_path = audio_files[i].relative_to(input)
80
- output_dir = output / relative_path.parent
81
- if not relative_path.name:
82
- output_dir = output
83
- relative_path = audio_files[i]
84
- output_name = relative_path.with_suffix(".dac").name
85
- output_path = output_dir / output_name
86
- output_path.parent.mkdir(parents=True, exist_ok=True)
87
-
88
- artifact.save(output_path)
89
-
90
-
91
- if __name__ == "__main__":
92
- args = argbind.parse_args()
93
- with argbind.scope(args):
94
- encode()
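
Note: an illustrative end-to-end round trip pairing encode() with the decode() utility above (directory names are hypothetical; model weights are fetched on first use):

from dac.utils.encode import encode
from dac.utils.decode import decode  # pre-deletion paths

encode(input="audio_in", output="codes", model_type="44khz", device="cpu")   # audio -> .dac
decode(input="codes", output="audio_out", model_type="44khz", device="cpu")  # .dac -> audio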