Dionyssos committed · Commit fe62fb4 · 1 Parent(s): e366cd5

cleanup Vq
audiocraft/builders.py CHANGED
@@ -4,15 +4,9 @@
  # This source code is licensed under the license found in the
  # LICENSE file in the root directory of this source tree.
 
- """
- All the functions to build the relevant models and modules
- from the Hydra config.
- """
-
  import typing as tp
  import omegaconf
  import torch
-
  from .encodec import CompressionModel, EncodecModel
  from .lm import LMModel
  from .seanet import SEANetDecoder
@@ -24,15 +18,15 @@ from .conditioners import (
      T5Conditioner,
  )
  from .unet import DiffusionUnet
- import audiocraft.quantization as qt
+ from .vq import ResidualVectorQuantizer
  from .utils.utils import dict_from_config
  from .diffusion_schedule import MultiBandProcessor, SampleProcessor
 
 
- def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
+ def get_quantizer(quantizer, cfg, dimension):
      klass = {
-         'no_quant': qt.DummyQuantizer,
-         'rvq': qt.ResidualVectorQuantizer
+         'no_quant': None,
+         'rvq': ResidualVectorQuantizer
      }[quantizer]
      kwargs = dict_from_config(getattr(cfg, quantizer))
      if quantizer != 'no_quant':
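For context, a minimal usage sketch of the slimmed-down get_quantizer; the config keys below are assumptions for illustration and are not part of this commit:

import omegaconf
from audiocraft.builders import get_quantizer

cfg = omegaconf.OmegaConf.create({'rvq': {'n_q': 4, 'bins': 2048}})  # assumed minimal RVQ config section
rvq = get_quantizer('rvq', cfg, dimension=128)  # presumably yields a ResidualVectorQuantizer
# After this cleanup, 'no_quant' maps to None rather than the removed DummyQuantizer.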
audiocraft/encodec.py CHANGED
@@ -9,7 +9,6 @@ Also defines the main interface that a model must follow to be usable as an audi
 
  from abc import ABC, abstractmethod
  import logging
- import math
  from pathlib import Path
  import typing as tp
 
@@ -19,8 +18,6 @@ import torch
  from torch import nn
  from transformers import EncodecModel as HFEncodecModel
 
- import audiocraft.quantization as qt
-
 
  logger = logging.getLogger()
 
audiocraft/lm.py CHANGED
@@ -433,7 +433,9 @@ class LMModel(StreamingModule):
 
  # print(f'{unconditional_state=} \n
  # print('Set All to Special')
- # next_token[:] = self.special_token_id
+
+ # RUNS with = 2047 just different of self.special_token_id -> 2047 is drill noise
+ # next_token[:] = self.special_token_id
 
 
 
@@ -449,7 +451,7 @@
  unconditional_state.clear()
 
  out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
-
+ print(f'{out_codes.shape=} {out_codes.min()} {out_codes.max()}\n')
  out_start_offset = start_offset if remove_prompts else 0
  out_codes = out_codes[..., out_start_offset:max_gen_len]
 
audiocraft/quantization/__init__.py DELETED
@@ -1,9 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- """RVQ."""
- # flake8: noqa
- from .vq import ResidualVectorQuantizer
- from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
audiocraft/quantization/base.py DELETED
@@ -1,99 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- """
- Base class for all quantizers.
- """
-
- from dataclasses import dataclass, field
- import typing as tp
-
- import torch
- from torch import nn
-
-
- @dataclass
- class QuantizedResult:
-     x: torch.Tensor
-     codes: torch.Tensor
-     bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
-     penalty: tp.Optional[torch.Tensor] = None
-     metrics: dict = field(default_factory=dict)
-
-
- class BaseQuantizer(nn.Module):
-     """Base class for quantizers.
-     """
-
-     def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
-         """
-         Given input tensor x, returns first the quantized (or approximately quantized)
-         representation along with quantized codes, bandwidth, and any penalty term for the loss.
-         Finally, this returns a dict of metrics to update logging etc.
-         Frame rate must be passed so that the bandwidth is properly computed.
-         """
-         raise NotImplementedError()
-
-     def encode(self, x: torch.Tensor) -> torch.Tensor:
-         """Encode a given input tensor with the specified sample rate at the given bandwidth."""
-         raise NotImplementedError()
-
-     def decode(self, codes: torch.Tensor) -> torch.Tensor:
-         """Decode the given codes to the quantized representation."""
-         raise NotImplementedError()
-
-     @property
-     def total_codebooks(self):
-         """Total number of codebooks."""
-         raise NotImplementedError()
-
-     @property
-     def num_codebooks(self):
-         """Number of active codebooks."""
-         raise NotImplementedError()
-
-     def set_num_codebooks(self, n: int):
-         """Set the number of active codebooks."""
-         raise NotImplementedError()
-
-
- class DummyQuantizer(BaseQuantizer):
-     """Fake quantizer that actually does not perform any quantization.
-     """
-     def __init__(self):
-         super().__init__()
-
-     def forward(self, x: torch.Tensor, frame_rate: int):
-         q = x.unsqueeze(1)
-         return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
-
-     def encode(self, x: torch.Tensor) -> torch.Tensor:
-         """Encode a given input tensor with the specified sample rate at the given bandwidth.
-         In the case of the DummyQuantizer, the codes are actually identical
-         to the input and resulting quantized representation as no quantization is done.
-         """
-         return x.unsqueeze(1)
-
-     def decode(self, codes: torch.Tensor) -> torch.Tensor:
-         """Decode the given codes to the quantized representation.
-         In the case of the DummyQuantizer, the codes are actually identical
-         to the input and resulting quantized representation as no quantization is done.
-         """
-         return codes.squeeze(1)
-
-     @property
-     def total_codebooks(self):
-         """Total number of codebooks."""
-         return 1
-
-     @property
-     def num_codebooks(self):
-         """Total number of codebooks."""
-         return self.total_codebooks
-
-     def set_num_codebooks(self, n: int):
-         """Set the number of active codebooks."""
-         raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
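For reference, the bandwidth figure reported by the deleted DummyQuantizer above is just the raw float32 rate; a toy restatement of that arithmetic with made-up sizes:

import torch

x = torch.zeros(2, 128, 50)   # assumed batch=2, dim=128, frames=50
q = x.unsqueeze(1)
frame_rate = 50
kbps = q.numel() * 32 * frame_rate / 1000 / len(x)  # 32 bits per value, per batch item, in kb/s
print(kbps)  # 128 * 50 * 50 * 32 / 1000 = 10240.0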
audiocraft/quantization/core_vq.py DELETED
@@ -1,405 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import typing as tp
8
-
9
- from einops import rearrange, repeat
10
- import flashy
11
- import torch
12
- from torch import nn, einsum
13
- import torch.nn.functional as F
14
-
15
-
16
- def exists(val: tp.Optional[tp.Any]) -> bool:
17
- return val is not None
18
-
19
-
20
- def default(val: tp.Any, d: tp.Any) -> tp.Any:
21
- return val if exists(val) else d
22
-
23
-
24
- def l2norm(t):
25
- return F.normalize(t, p=2, dim=-1)
26
-
27
-
28
- def ema_inplace(moving_avg, new, decay: float):
29
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
30
-
31
-
32
- def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
33
- return (x + epsilon) / (x.sum() + n_categories * epsilon)
34
-
35
-
36
- def uniform_init(*shape: int):
37
- t = torch.empty(shape)
38
- nn.init.kaiming_uniform_(t)
39
- return t
40
-
41
-
42
- def sample_vectors(samples, num: int):
43
- num_samples, device = samples.shape[0], samples.device
44
-
45
- if num_samples >= num:
46
- indices = torch.randperm(num_samples, device=device)[:num]
47
- else:
48
- indices = torch.randint(0, num_samples, (num,), device=device)
49
-
50
- return samples[indices]
51
-
52
-
53
- def kmeans(samples, num_clusters: int, num_iters: int = 10):
54
- dim, dtype = samples.shape[-1], samples.dtype
55
-
56
- means = sample_vectors(samples, num_clusters)
57
-
58
- for _ in range(num_iters):
59
- diffs = rearrange(samples, "n d -> n () d") - rearrange(
60
- means, "c d -> () c d"
61
- )
62
- dists = -(diffs ** 2).sum(dim=-1)
63
-
64
- buckets = dists.max(dim=-1).indices
65
- bins = torch.bincount(buckets, minlength=num_clusters)
66
- zero_mask = bins == 0
67
- bins_min_clamped = bins.masked_fill(zero_mask, 1)
68
-
69
- new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
70
- new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
71
- new_means = new_means / bins_min_clamped[..., None]
72
-
73
- means = torch.where(zero_mask[..., None], means, new_means)
74
-
75
- return means, bins
76
-
77
-
78
- def orthogonal_loss_fn(t):
79
- # eq (2) from https://arxiv.org/abs/2112.00384
80
- n = t.shape[0]
81
- normed_codes = l2norm(t)
82
- identity = torch.eye(n, device=t.device)
83
- cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
84
- return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
85
-
86
-
87
- class EuclideanCodebook(nn.Module):
88
- """Codebook with Euclidean distance.
89
-
90
- Args:
91
- dim (int): Dimension.
92
- codebook_size (int): Codebook size.
93
- kmeans_init (bool): Whether to use k-means to initialize the codebooks.
94
- If set to true, run the k-means algorithm on the first training batch and use
95
- the learned centroids as initialization.
96
- kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
97
- decay (float): Decay for exponential moving average over the codebooks.
98
- epsilon (float): Epsilon value for numerical stability.
99
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
100
- that have an exponential moving average cluster size less than the specified threshold with
101
- randomly selected vector from the current batch.
102
- """
103
- def __init__(
104
- self,
105
- dim: int,
106
- codebook_size: int,
107
- kmeans_init: int = False,
108
- kmeans_iters: int = 10,
109
- decay: float = 0.8,
110
- epsilon: float = 1e-5,
111
- threshold_ema_dead_code: int = 2,
112
- ):
113
- super().__init__()
114
- self.decay = decay
115
- init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
116
- embed = init_fn(codebook_size, dim)
117
-
118
- self.codebook_size = codebook_size
119
-
120
- self.kmeans_iters = kmeans_iters
121
- self.epsilon = epsilon
122
- self.threshold_ema_dead_code = threshold_ema_dead_code
123
-
124
- self.register_buffer("inited", torch.Tensor([not kmeans_init]))
125
- self.register_buffer("cluster_size", torch.zeros(codebook_size))
126
- self.register_buffer("embed", embed)
127
- self.register_buffer("embed_avg", embed.clone())
128
-
129
- @torch.jit.ignore
130
- def init_embed_(self, data):
131
- if self.inited:
132
- return
133
-
134
- embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
135
- self.embed.data.copy_(embed)
136
- self.embed_avg.data.copy_(embed.clone())
137
- self.cluster_size.data.copy_(cluster_size)
138
- self.inited.data.copy_(torch.Tensor([True]))
139
- # Make sure all buffers across workers are in sync after initialization
140
- flashy.distrib.broadcast_tensors(self.buffers())
141
-
142
- def replace_(self, samples, mask):
143
- modified_codebook = torch.where(
144
- mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
145
- )
146
- self.embed.data.copy_(modified_codebook)
147
-
148
- def expire_codes_(self, batch_samples):
149
- if self.threshold_ema_dead_code == 0:
150
- return
151
-
152
- expired_codes = self.cluster_size < self.threshold_ema_dead_code
153
- if not torch.any(expired_codes):
154
- return
155
-
156
- batch_samples = rearrange(batch_samples, "... d -> (...) d")
157
- self.replace_(batch_samples, mask=expired_codes)
158
- flashy.distrib.broadcast_tensors(self.buffers())
159
-
160
- def preprocess(self, x):
161
- x = rearrange(x, "... d -> (...) d")
162
- return x
163
-
164
- def quantize(self, x):
165
- embed = self.embed.t()
166
- dist = -(
167
- x.pow(2).sum(1, keepdim=True)
168
- - 2 * x @ embed
169
- + embed.pow(2).sum(0, keepdim=True)
170
- )
171
- embed_ind = dist.max(dim=-1).indices
172
- return embed_ind
173
-
174
- def postprocess_emb(self, embed_ind, shape):
175
- return embed_ind.view(*shape[:-1])
176
-
177
- def dequantize(self, embed_ind):
178
- quantize = F.embedding(embed_ind, self.embed)
179
- return quantize
180
-
181
- def encode(self, x):
182
- shape = x.shape
183
- # pre-process
184
- x = self.preprocess(x)
185
- # quantize
186
- embed_ind = self.quantize(x)
187
- # post-process
188
- embed_ind = self.postprocess_emb(embed_ind, shape)
189
- return embed_ind
190
-
191
- def decode(self, embed_ind):
192
- quantize = self.dequantize(embed_ind)
193
- return quantize
194
-
195
- def forward(self, x):
196
- shape, dtype = x.shape, x.dtype
197
- x = self.preprocess(x)
198
- self.init_embed_(x)
199
-
200
- embed_ind = self.quantize(x)
201
- embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
202
- embed_ind = self.postprocess_emb(embed_ind, shape)
203
- quantize = self.dequantize(embed_ind)
204
-
205
- if self.training:
206
- # We do the expiry of code at that point as buffers are in sync
207
- # and all the workers will take the same decision.
208
- self.expire_codes_(x)
209
- ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
210
- embed_sum = x.t() @ embed_onehot
211
- ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
212
- cluster_size = (
213
- laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
214
- * self.cluster_size.sum()
215
- )
216
- embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
217
- self.embed.data.copy_(embed_normalized)
218
-
219
- return quantize, embed_ind
220
-
221
-
222
- class VectorQuantization(nn.Module):
223
- """Vector quantization implementation.
224
- Currently supports only euclidean distance.
225
-
226
- Args:
227
- dim (int): Dimension
228
- codebook_size (int): Codebook size
229
- codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
230
- decay (float): Decay for exponential moving average over the codebooks.
231
- epsilon (float): Epsilon value for numerical stability.
232
- kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
233
- kmeans_iters (int): Number of iterations used for kmeans initialization.
234
- threshold_ema_dead_code (int):
235
- channels_last (bool): Channels are the last dimension in the input tensors.
236
- commitment_weight (float): Weight for commitment loss.
237
- orthogonal_reg_weight (float): Orthogonal regularization weights.
238
- orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
239
- orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
240
- for orthogonal regularization.
241
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
242
- that have an exponential moving average cluster size less than the specified threshold with
243
- randomly selected vector from the current batch.
244
- """
245
- def __init__(
246
- self,
247
- dim: int,
248
- codebook_size: int,
249
- codebook_dim: tp.Optional[int] = None,
250
- decay: float = 0.8,
251
- epsilon: float = 1e-5,
252
- kmeans_init: bool = False,
253
- kmeans_iters: int = 10,
254
- threshold_ema_dead_code: int = 2,
255
- channels_last: bool = False,
256
- commitment_weight: float = 1.,
257
- orthogonal_reg_weight: float = 0.0,
258
- orthogonal_reg_active_codes_only: bool = False,
259
- orthogonal_reg_max_codes: tp.Optional[int] = None,
260
- ):
261
- super().__init__()
262
- _codebook_dim: int = default(codebook_dim, dim)
263
-
264
- requires_projection = _codebook_dim != dim
265
- self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
266
- self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
267
-
268
- self.epsilon = epsilon
269
- self.commitment_weight = commitment_weight
270
-
271
- self.orthogonal_reg_weight = orthogonal_reg_weight
272
- self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
273
- self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
274
-
275
- self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
276
- kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
277
- decay=decay, epsilon=epsilon,
278
- threshold_ema_dead_code=threshold_ema_dead_code)
279
- self.codebook_size = codebook_size
280
-
281
- self.channels_last = channels_last
282
-
283
- @property
284
- def codebook(self):
285
- return self._codebook.embed
286
-
287
- @property
288
- def inited(self):
289
- return self._codebook.inited
290
-
291
- def _preprocess(self, x):
292
- if not self.channels_last:
293
- x = rearrange(x, "b d n -> b n d")
294
- return x
295
-
296
- def _postprocess(self, quantize):
297
- if not self.channels_last:
298
- quantize = rearrange(quantize, "b n d -> b d n")
299
- return quantize
300
-
301
- def encode(self, x):
302
- x = self._preprocess(x)
303
- x = self.project_in(x)
304
- embed_in = self._codebook.encode(x)
305
- return embed_in
306
-
307
- def decode(self, embed_ind):
308
- quantize = self._codebook.decode(embed_ind)
309
- quantize = self.project_out(quantize)
310
- quantize = self._postprocess(quantize)
311
- return quantize
312
-
313
- def forward(self, x):
314
- device = x.device
315
- x = self._preprocess(x)
316
-
317
- x = self.project_in(x)
318
- quantize, embed_ind = self._codebook(x)
319
-
320
- if self.training:
321
- quantize = x + (quantize - x).detach()
322
-
323
- loss = torch.tensor([0.0], device=device, requires_grad=self.training)
324
-
325
- if self.training:
326
- if self.commitment_weight > 0:
327
- commit_loss = F.mse_loss(quantize.detach(), x)
328
- loss = loss + commit_loss * self.commitment_weight
329
-
330
- if self.orthogonal_reg_weight > 0:
331
- codebook = self.codebook
332
-
333
- if self.orthogonal_reg_active_codes_only:
334
- # only calculate orthogonal loss for the activated codes for this batch
335
- unique_code_ids = torch.unique(embed_ind)
336
- codebook = codebook[unique_code_ids]
337
-
338
- num_codes = codebook.shape[0]
339
- if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
340
- rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
341
- codebook = codebook[rand_ids]
342
-
343
- orthogonal_reg_loss = orthogonal_loss_fn(codebook)
344
- loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
345
-
346
- quantize = self.project_out(quantize)
347
- quantize = self._postprocess(quantize)
348
-
349
- return quantize, embed_ind, loss
350
-
351
-
352
- class ResidualVectorQuantization(nn.Module):
353
- """Residual vector quantization implementation.
354
-
355
- Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
356
- """
357
- def __init__(self, *, num_quantizers, **kwargs):
358
- super().__init__()
359
- self.layers = nn.ModuleList(
360
- [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
361
- )
362
-
363
- def forward(self, x, n_q: tp.Optional[int] = None):
364
- quantized_out = 0.0
365
- residual = x
366
-
367
- all_losses = []
368
- all_indices = []
369
-
370
- n_q = n_q or len(self.layers)
371
-
372
- for i, layer in enumerate(self.layers[:n_q]):
373
- quantized, indices, loss = layer(residual)
374
- quantized = quantized.detach()
375
- residual = residual - quantized
376
- quantized_out = quantized_out + quantized
377
- all_indices.append(indices)
378
- all_losses.append(loss)
379
-
380
- if self.training:
381
- # Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
382
- quantized_out = x + (quantized_out - x).detach()
383
-
384
- out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
385
- return quantized_out, out_indices, out_losses
386
-
387
- def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
388
- residual = x
389
- all_indices = []
390
- n_q = n_q or len(self.layers)
391
- for layer in self.layers[:n_q]:
392
- indices = layer.encode(residual)
393
- quantized = layer.decode(indices)
394
- residual = residual - quantized
395
- all_indices.append(indices)
396
- out_indices = torch.stack(all_indices)
397
- return out_indices
398
-
399
- def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
400
- quantized_out = torch.tensor(0.0, device=q_indices.device)
401
- for i, indices in enumerate(q_indices):
402
- layer = self.layers[i]
403
- quantized = layer.decode(indices)
404
- quantized_out = quantized_out + quantized
405
- return quantized_out
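The training-time codebook update deleted above is an exponential-moving-average scheme; a compact stand-alone restatement of that arithmetic (toy sizes, decay and epsilon as in the removed code):

import torch
import torch.nn.functional as F

decay, epsilon, bins, dim = 0.8, 1e-5, 8, 4
cluster_size = torch.zeros(bins)
embed_avg = torch.randn(bins, dim)
x = torch.randn(100, dim)                   # flattened encoder outputs
embed_ind = torch.randint(0, bins, (100,))  # nearest-codeword indices
onehot = F.one_hot(embed_ind, bins).float()

cluster_size.mul_(decay).add_(onehot.sum(0), alpha=1 - decay)      # ema_inplace on cluster sizes
embed_avg.mul_(decay).add_((x.t() @ onehot).t(), alpha=1 - decay)  # ema_inplace on codeword sums
smoothed = (cluster_size + epsilon) / (cluster_size.sum() + bins * epsilon) * cluster_size.sum()
embed = embed_avg / smoothed.unsqueeze(1)                          # updated codebook entries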
audiocraft/{quantization/vq.py → vq.py} RENAMED
@@ -1,19 +1,157 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
  import math
  import typing as tp
-
+ from dataclasses import dataclass, field
+ import typing as tp
  import torch
+ from torch import nn
+ from einops import rearrange
+ import torch.nn.functional as F
+
+ @dataclass
+ class QuantizedResult:
+     x: torch.Tensor
+     codes: torch.Tensor
+     bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
+     penalty: tp.Optional[torch.Tensor] = None
+     metrics: dict = field(default_factory=dict)
+
+
+ class EuclideanCodebook(nn.Module):
+     def __init__(
+         self,
+         dim,
+         codebook_size,
+         kmeans_init=False,
+         kmeans_iters=10,
+         decay=0.8,
+         epsilon=1e-5,
+     ):
+         super().__init__()
+         self.decay=decay
+         init_fn=uniform_init if not kmeans_init else torch.zeros
+         embed = init_fn(codebook_size, dim)
+
+         self.codebook_size = codebook_size
+
+         self.kmeans_iters = kmeans_iters
+         self.epsilon = epsilon
+
+         self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+         self.register_buffer("cluster_size", torch.zeros(codebook_size))
+         self.register_buffer("embed", embed)
+         self.register_buffer("embed_avg", embed.clone())
+
+     @torch.jit.ignore
+     def init_embed_(self, data):
+         if self.inited:
+             return
+
+         embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+         self.embed.data.copy_(embed)
+         self.embed_avg.data.copy_(embed.clone())
+         self.cluster_size.data.copy_(cluster_size)
+         self.inited.data.copy_(torch.Tensor([True]))
+         # Make sure all buffers across workers are in sync after initialization
+         # flashy.distrib.broadcast_tensors(self.buffers()) # brodcast param values to all GPUS
+
+     def postprocess_emb(self, embed_ind, shape):
+         return embed_ind.view(*shape[:-1])
+
+     def dequantize(self, embed_ind):
+         quantize = F.embedding(embed_ind, self.embed)
+         # print('\n\nDE QUANT\n\n', quantize.shape) # (1, 35, 128) -> also arrives here for special_token
+         return quantize
+
+     def decode(self, embed_ind):
+         quantize = self.dequantize(embed_ind)
+         return quantize
+
+
+ class VectorQuantization(nn.Module):
+
+     def __init__(
+         self,
+         dim,
+         codebook_size,
+         codebook_dim=None,
+         decay=0.8,
+         epsilon=1e-5,
+         kmeans_init=False,
+         kmeans_iters=10,
+         channels_last=False,
+     ):
+         super().__init__()
+         # _codebook_dim: int = default(codebook_dim, dim)
+         _codebook_dim = codebook_dim if codebook_dim is not None else dim
+
+         requires_projection = _codebook_dim != dim
+         self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+         self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+         self._codebook = EuclideanCodebook(dim=_codebook_dim,
+                                            codebook_size=codebook_size,
+                                            kmeans_init=kmeans_init,
+                                            kmeans_iters=kmeans_iters,
+                                            decay=decay,
+                                            epsilon=epsilon)
+         self.codebook_size = codebook_size
+
+         self.channels_last = channels_last
+
+     @property
+     def codebook(self):
+         return self._codebook.embed
+
+     @property
+     def inited(self):
+         return self._codebook.inited
+
+     def _postprocess(self, quantize):
+         if not self.channels_last:
+             quantize = rearrange(quantize, "b n d -> b d n")
+         return quantize
+
+     def decode(self, embed_ind):
+         quantize = self._codebook.decode(embed_ind)
+         quantize = self.project_out(quantize)
+         quantize = self._postprocess(quantize)
+         return quantize
+
+
+ class ResidualVectorQuantization(nn.Module):
+     """Residual vector quantization implementation.
+
+     Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+     """
+     def __init__(self, *, num_quantizers, **kwargs):
+         super().__init__()
+         self.layers = nn.ModuleList(
+             [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+         )
+
+     def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+         quantized_out = torch.tensor(0.0, device=q_indices.device)
+         for i, indices in enumerate(q_indices):
+             layer = self.layers[i]
+             quantized = layer.decode(indices)
+             quantized_out = quantized_out + quantized
+         return quantized_out
 
- from .base import BaseQuantizer, QuantizedResult
- from .core_vq import ResidualVectorQuantization
+
+ # ------------------------------------- END core_vq.py
 
 
- class ResidualVectorQuantizer(BaseQuantizer):
+ class ResidualVectorQuantizer(nn.Module):
  """Residual Vector Quantizer.
 
  Args:
@@ -59,6 +197,7 @@ class ResidualVectorQuantizer(BaseQuantizer):
  self.orthogonal_reg_weight = orthogonal_reg_weight
  self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
  self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+ print(f' {kmeans_init=}\n\n\n\n')
  self.vq = ResidualVectorQuantization(
      dim=self.dimension,
      codebook_size=self.bins,
@@ -66,10 +205,6 @@ class ResidualVectorQuantizer(BaseQuantizer):
      decay=self.decay,
      kmeans_init=self.kmeans_init,
      kmeans_iters=self.kmeans_iters,
-     threshold_ema_dead_code=self.threshold_ema_dead_code,
-     orthogonal_reg_weight=self.orthogonal_reg_weight,
-     orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
-     orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
      channels_last=False
  )
 
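What survives this cleanup is a decode-only path: each layer's codes index that layer's codebook and the per-layer embeddings are summed. A plain-torch illustration of that behaviour (sizes are assumptions, not taken from the commit):

import torch
import torch.nn.functional as F

n_q, bins, dim, frames = 4, 2048, 128, 35
codebooks = [torch.randn(bins, dim) for _ in range(n_q)]  # stands in for each layer's embed buffer
codes = torch.randint(0, bins, (n_q, 1, frames))          # (n_q, batch, frames)
latent = sum(F.embedding(codes[i], codebooks[i]) for i in range(n_q))  # (1, frames, dim)
latent = latent.transpose(1, 2)  # (1, dim, frames), matching the "b n d -> b d n" postprocess above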
demo.py CHANGED
@@ -1,12 +1,13 @@
  from audiocraft.audiogen import AudioGen #, audio_write
-
+ import audiofile
+ import numpy as np
 
  print('\n\n\n\n___________________')
 
- txt = 'austrian music'
+ txt = 'sea waves rock crash pirates'
 
  sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
- sound_generator.set_generation_params(duration=4.7) # why is generating so long at 14 seconds
+ sound_generator.set_generation_params(duration=.7) # why is generating so long at 14 seconds
 
  x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
  x /= np.abs(x).max() + 1e-7
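The new audiofile and numpy imports suggest demo.py now writes the normalized array to disk, but the shown hunk ends before that call; the following is only a plausible continuation, with the file name and sample rate as assumptions:

import audiofile
audiofile.write('out.wav', x, 16000)  # hypothetical: assumes 16 kHz mono output from audiogen-medium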