Dionyssos committed
Commit 5067878 · 1 Parent(s): c5e1f80
README.md CHANGED
@@ -2,20 +2,17 @@
 license: mit
 language:
 - en
-pipeline_tag: text-to-speech
+pipeline_tag: audio-generation
 tags:
 - audiocraft
 - audiogen
 - styletts2
-- audio
-- synthesis
 - shift
 - audeering
-- dkounadis
 - sound
-- scene
-- acoustic-scene
 - audio-generation
+- text-to-speech
+- mimic3
 ---
 
 
@@ -35,13 +32,23 @@ tags:
 
 ```
 git clone https://huggingface.co/dkounadis/artificial-styletts2
+```
+
+<details>
+<summary>
+Create virtualenv
+</summary>
 
+```
 virtualenv --python=python3 ~/.envs/.my_env
 source ~/.envs/.my_env/bin/activate
 cd artificial-styletts2/
 pip install -r requirements.txt
 ```
 
+
+</details>
+
 Start Flask
 
 ```
@@ -128,4 +135,10 @@ Client - Describe any sound with words and it will be played back to you.
 
 ```python
 python live_demo.py # will ask text input & play soundscape
+```
+
+# Simple Demo
+
+```python
+CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=/data/dkounadis/.hf7/ CUDA_VISIBLE_DEVICES=4 python demo.py
 ```
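The updated README describes a Flask server plus a live client (`live_demo.py`) that asks for a text prompt and plays back the generated soundscape. As a rough, non-authoritative illustration of that client side, the sketch below posts a description to a locally running server and saves the response as a WAV file; the URL, the `text` form field, and the raw-WAV response are assumptions for illustration, not the repository's documented API.

```python
# Hypothetical client sketch: the endpoint URL, "text" field and WAV response are assumptions.
import requests

description = input("Describe a sound: ")        # e.g. "dogs barking in the rain"
resp = requests.post("http://localhost:5000/",   # assumed address of the Flask app
                     data={"text": description},
                     timeout=300)
resp.raise_for_status()
with open("soundscape.wav", "wb") as f:          # save whatever audio the server returned
    f.write(resp.content)
print("wrote soundscape.wav")
```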
audiocraft/builders.py CHANGED
@@ -15,7 +15,7 @@ import audiocraft
 import omegaconf
 import torch
 
-from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
+from .encodec import CompressionModel, EncodecModel
 from .lm import LMModel
 from .seanet import SEANetEncoder, SEANetDecoder
 from .codebooks_patterns import (
@@ -211,20 +211,3 @@ def get_processor(cfg, sample_rate: int = 24000):
         if cfg.name == "multi_band_processor":
             sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
     return sample_processor
-
-
-
-
-
-def get_wrapped_compression_model(
-        compression_model: CompressionModel,
-        cfg: omegaconf.DictConfig) -> CompressionModel:
-    if hasattr(cfg, 'interleave_stereo_codebooks'):
-        if cfg.interleave_stereo_codebooks.use:
-            kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
-            kwargs.pop('use')
-            compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
-    if hasattr(cfg, 'compression_model_n_q'):
-        if cfg.compression_model_n_q is not None:
-            compression_model.set_num_codebooks(cfg.compression_model_n_q)
-    return compression_model
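For context on the dropped helper: `get_wrapped_compression_model` only applied optional, config-driven tweaks to a compression model. A minimal sketch of that pattern, using plain `omegaconf` and a stand-in `DummyCodec` in place of the real `CompressionModel`:

```python
from omegaconf import OmegaConf

class DummyCodec:
    """Stand-in for CompressionModel; only the call used by the removed helper."""
    def set_num_codebooks(self, n: int) -> None:
        print(f"using {n} codebooks")

cfg = OmegaConf.create({
    "interleave_stereo_codebooks": {"use": False},   # stereo wrapping disabled
    "compression_model_n_q": 4,                      # restrict the active codebooks
})

codec = DummyCodec()
# same checks the removed get_wrapped_compression_model performed
if hasattr(cfg, "interleave_stereo_codebooks") and cfg.interleave_stereo_codebooks.use:
    pass  # would wrap the codec for stereo; dropped together with InterleaveStereoCompressionModel
if hasattr(cfg, "compression_model_n_q") and cfg.compression_model_n_q is not None:
    codec.set_num_codebooks(cfg.compression_model_n_q)
```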
audiocraft/conditioners.py CHANGED
@@ -1,5 +1,4 @@
 from collections import defaultdict
-from copy import deepcopy
 from dataclasses import dataclass, field
 from itertools import chain
 import logging
@@ -10,20 +9,12 @@ import re
 import typing as tp
 import warnings
 import soundfile
-from num2words import num2words
-import spacy
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
-import torch.nn.functional as F
-from torch.nn.utils.rnn import pad_sequence
 from .streaming import StreamingModule
 
 
-from .streaming import StreamingModule
-from .transformer import create_sin_embedding
-
-
 from .quantization import ResidualVectorQuantizer
 from .utils.autocast import TorchAutocast
 from .utils.cache import EmbeddingCache
@@ -112,102 +103,10 @@ class Tokenizer:
         raise NotImplementedError()
 
 
-class WhiteSpaceTokenizer(Tokenizer):
-    """This tokenizer should be used for natural language descriptions.
-    For example:
-        ["he didn't, know he's going home.", 'shorter sentence'] =>
-        [[78, 62, 31, 4, 78, 25, 19, 34],
-        [59, 77, 0, 0, 0, 0, 0, 0]]
-    """
-    PUNCTUATION = "?:!.,;"
-
-    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
-                 lemma: bool = True, stopwords: bool = True) -> None:
-        self.n_bins = n_bins
-        self.pad_idx = pad_idx
-        self.lemma = lemma
-        self.stopwords = stopwords
-        try:
-            self.nlp = spacy.load(language)
-        except IOError:
-            spacy.cli.download(language)  # type: ignore
-            self.nlp = spacy.load(language)
-
-    @tp.no_type_check
-    def __call__(self, texts: tp.List[tp.Optional[str]],
-                 return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        """Take a list of strings and convert them to a tensor of indices.
-
-        Args:
-            texts (list[str]): List of strings.
-            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
-        Returns:
-            tuple[torch.Tensor, torch.Tensor]:
-                - Indices of words in the LUT.
-                - And a mask indicating where the padding tokens are
-        """
-        output, lengths = [], []
-        texts = deepcopy(texts)
-        for i, text in enumerate(texts):
-            # if current sample doesn't have a certain attribute, replace with pad token
-            if text is None:
-                output.append(torch.Tensor([self.pad_idx]))
-                lengths.append(0)
-                continue
-
-            # convert numbers to words
-            text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text)  # type: ignore
-            # normalize text
-            text = self.nlp(text)  # type: ignore
-            # remove stopwords
-            if self.stopwords:
-                text = [w for w in text if not w.is_stop]  # type: ignore
-            # remove punctuation
-            text = [w for w in text if w.text not in self.PUNCTUATION]  # type: ignore
-            # lemmatize if needed
-            text = [getattr(t, "lemma_" if self.lemma else "text") for t in text]  # type: ignore
-
-            texts[i] = " ".join(text)
-            lengths.append(len(text))
-            # convert to tensor
-            tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
-            output.append(tokens)
-
-        mask = length_to_mask(torch.IntTensor(lengths)).int()
-        padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
-        if return_text:
-            return padded_output, mask, texts  # type: ignore
-        return padded_output, mask
-
-
-class NoopTokenizer(Tokenizer):
-    """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
-    The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
-    strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
-    split it to ["Jeff", "Buckley"] and return an index per word.
-
-    For example:
-        ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
-        ["Metal", "Rock", "Classical"] => [0, 223, 51]
-    """
-    def __init__(self, n_bins: int, pad_idx: int = 0):
-        self.n_bins = n_bins
-        self.pad_idx = pad_idx
-
-    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        output, lengths = [], []
-        for text in texts:
-            # if current sample doesn't have a certain attribute, replace with pad token
-            if text is None:
-                output.append(self.pad_idx)
-                lengths.append(0)
-            else:
-                output.append(hash_trick(text, self.n_bins))
-                lengths.append(1)
-
-        tokens = torch.LongTensor(output).unsqueeze(1)
-        mask = length_to_mask(torch.IntTensor(lengths)).int()
-        return tokens, mask
+
+
 
 
 class BaseConditioner(nn.Module):
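Both deleted tokenizers reduce to hashing strings into a fixed number of bins and building a padding mask. The sketch below shows that core idea in plain torch; `hash_to_bins` is a simple stand-in for audiocraft's `hash_trick`, and this is not the removed classes themselves.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def hash_to_bins(word: str, n_bins: int) -> int:
    # stand-in for audiocraft's hash_trick: map any string to one of n_bins ids
    return hash(word) % n_bins

def tokenize(texts, n_bins: int = 1000, pad_idx: int = 0):
    """Whitespace-split each description, hash every word to an id, pad to equal length."""
    ids = [torch.tensor([hash_to_bins(w, n_bins) for w in t.split()], dtype=torch.long)
           for t in texts]
    lengths = torch.tensor([len(x) for x in ids])
    tokens = pad_sequence(ids, batch_first=True, padding_value=pad_idx)
    mask = torch.arange(tokens.shape[1])[None, :] < lengths[:, None]   # True on real tokens
    return tokens, mask.int()

tokens, mask = tokenize(["dogs barking in the rain", "water"])
print(tokens.shape, mask)
```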
audiocraft/encodec.py CHANGED
@@ -256,251 +256,4 @@ class EncodecModel(CompressionModel):
256
 
257
  def decode_latent(self, codes: torch.Tensor):
258
  """Decode from the discrete codes to continuous latent space."""
259
- return self.quantizer.decode(codes)
260
-
261
-
262
- class DAC(CompressionModel):
263
- def __init__(self, model_type: str = "44khz"):
264
- super().__init__()
265
- try:
266
- import dac.utils
267
- except ImportError:
268
- raise RuntimeError("Could not import dac, make sure it is installed, "
269
- "please run `pip install descript-audio-codec`")
270
- self.model = dac.utils.load_model(model_type=model_type)
271
- self.n_quantizers = self.total_codebooks
272
- self.model.eval()
273
-
274
- def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
275
- # We don't support training with this.
276
- raise NotImplementedError("Forward and training with DAC not supported.")
277
-
278
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
279
- codes = self.model.encode(x, self.n_quantizers)[1]
280
- return codes[:, :self.n_quantizers], None
281
-
282
- def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
283
- assert scale is None
284
- z_q = self.decode_latent(codes)
285
- return self.model.decode(z_q)
286
-
287
- def decode_latent(self, codes: torch.Tensor):
288
- """Decode from the discrete codes to continuous latent space."""
289
- return self.model.quantizer.from_codes(codes)[0]
290
-
291
- @property
292
- def channels(self) -> int:
293
- return 1
294
-
295
- @property
296
- def frame_rate(self) -> float:
297
- return self.model.sample_rate / self.model.hop_length
298
-
299
- @property
300
- def sample_rate(self) -> int:
301
- return self.model.sample_rate
302
-
303
- @property
304
- def cardinality(self) -> int:
305
- return self.model.codebook_size
306
-
307
- @property
308
- def num_codebooks(self) -> int:
309
- return self.n_quantizers
310
-
311
- @property
312
- def total_codebooks(self) -> int:
313
- return self.model.n_codebooks
314
-
315
- def set_num_codebooks(self, n: int):
316
- """Set the active number of codebooks used by the quantizer.
317
- """
318
- assert n >= 1
319
- assert n <= self.total_codebooks
320
- self.n_quantizers = n
321
-
322
-
323
- class HFEncodecCompressionModel(CompressionModel):
324
- """Wrapper around HuggingFace Encodec.
325
- """
326
- def __init__(self, model: HFEncodecModel):
327
- super().__init__()
328
- self.model = model
329
- bws = self.model.config.target_bandwidths
330
- num_codebooks = [
331
- bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
332
- for bw in bws
333
- ]
334
- deltas = [nc - int(nc) for nc in num_codebooks]
335
- # Checking we didn't do some bad maths and we indeed have integers!
336
- assert all(deltas) <= 1e-3, deltas
337
- self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
338
- self.set_num_codebooks(max(self.possible_num_codebooks))
339
-
340
- def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
341
- # We don't support training with this.
342
- raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
343
-
344
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
345
- bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
346
- bandwidth = self.model.config.target_bandwidths[bandwidth_index]
347
- res = self.model.encode(x, None, bandwidth)
348
- assert len(res[0]) == 1
349
- assert len(res[1]) == 1
350
- return res[0][0], res[1][0]
351
-
352
- def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
353
- if scale is None:
354
- scales = [None] # type: ignore
355
- else:
356
- scales = scale # type: ignore
357
- res = self.model.decode(codes[None], scales)
358
- return res[0]
359
-
360
- def decode_latent(self, codes: torch.Tensor):
361
- """Decode from the discrete codes to continuous latent space."""
362
- return self.model.quantizer.decode(codes.transpose(0, 1))
363
-
364
- @property
365
- def channels(self) -> int:
366
- return self.model.config.audio_channels
367
-
368
- @property
369
- def frame_rate(self) -> float:
370
- hop_length = int(np.prod(self.model.config.upsampling_ratios))
371
- return self.sample_rate / hop_length
372
-
373
- @property
374
- def sample_rate(self) -> int:
375
- return self.model.config.sampling_rate
376
-
377
- @property
378
- def cardinality(self) -> int:
379
- return self.model.config.codebook_size
380
-
381
- @property
382
- def num_codebooks(self) -> int:
383
- return self._num_codebooks
384
-
385
- @property
386
- def total_codebooks(self) -> int:
387
- return max(self.possible_num_codebooks)
388
-
389
- def set_num_codebooks(self, n: int):
390
- """Set the active number of codebooks used by the quantizer.
391
- """
392
- if n not in self.possible_num_codebooks:
393
- raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
394
- self._num_codebooks = n
395
-
396
-
397
- class InterleaveStereoCompressionModel(CompressionModel):
398
- """Wraps a CompressionModel to support stereo inputs. The wrapped model
399
- will be applied independently to the left and right channels, and both codebooks
400
- will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per
401
- channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
402
- `per_timestep`.
403
-
404
- Args:
405
- model (CompressionModel): Compression model to wrap.
406
- per_timestep (bool): Whether to interleave on the timestep dimension
407
- or on the codebooks dimension.
408
- """
409
- def __init__(self, model: CompressionModel, per_timestep: bool = False):
410
- super().__init__()
411
- self.model = model
412
- self.per_timestep = per_timestep
413
- assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"
414
-
415
- @property
416
- def total_codebooks(self):
417
- return self.model.total_codebooks
418
-
419
- @property
420
- def num_codebooks(self):
421
- """Active number of codebooks used by the quantizer.
422
-
423
- ..Warning:: this reports the number of codebooks after the interleaving
424
- of the codebooks!
425
- """
426
- return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2
427
-
428
- def set_num_codebooks(self, n: int):
429
- """Set the active number of codebooks used by the quantizer.
430
-
431
- ..Warning:: this sets the number of codebooks before the interleaving!
432
- """
433
- self.model.set_num_codebooks(n)
434
-
435
- @property
436
- def num_virtual_steps(self) -> float:
437
- """Return the number of virtual steps, e.g. one real step
438
- will be split into that many steps.
439
- """
440
- return 2 if self.per_timestep else 1
441
-
442
- @property
443
- def frame_rate(self) -> float:
444
- return self.model.frame_rate * self.num_virtual_steps
445
-
446
- @property
447
- def sample_rate(self) -> int:
448
- return self.model.sample_rate
449
-
450
- @property
451
- def channels(self) -> int:
452
- return 2
453
-
454
- @property
455
- def cardinality(self):
456
- """Cardinality of each codebook.
457
- """
458
- return self.model.cardinality
459
-
460
- def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
461
- raise NotImplementedError("Not supported, use encode and decode.")
462
-
463
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
464
- B, C, T = x.shape
465
- assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"
466
-
467
- indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
468
- indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
469
- indices = torch.stack([indices_c0, indices_c1], dim=0)
470
- scales: tp.Optional[torch.Tensor] = None
471
- if scales_c0 is not None and scales_c1 is not None:
472
- scales = torch.stack([scales_c0, scales_c1], dim=1)
473
-
474
- if self.per_timestep:
475
- indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
476
- else:
477
- indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)
478
-
479
- return (indices, scales)
480
-
481
- def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
482
- if self.per_timestep:
483
- codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
484
- else:
485
- codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
486
- return codes[0], codes[1]
487
-
488
- def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
489
- B, K, T = codes.shape
490
- assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
491
- assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"
492
-
493
- scale_c0, scale_c1 = None, None
494
- if scale is not None:
495
- assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
496
- scale_c0 = scale[0, ...]
497
- scale_c1 = scale[1, ...]
498
-
499
- codes_c0, codes_c1 = self.get_left_right_codes(codes)
500
- audio_c0 = self.model.decode(codes_c0, scale_c0)
501
- audio_c1 = self.model.decode(codes_c1, scale_c1)
502
- return torch.cat([audio_c0, audio_c1], dim=1)
503
-
504
- def decode_latent(self, codes: torch.Tensor):
505
- """Decode from the discrete codes to continuous latent space."""
506
- raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
 
 
     def decode_latent(self, codes: torch.Tensor):
         """Decode from the discrete codes to continuous latent space."""
+        return self.quantizer.decode(codes)
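The deleted `InterleaveStereoCompressionModel` encoded the left and right channels independently and interleaved the two code streams, turning per-channel codes of shape `[B, K, T]` into `[B, 2K, T]` (per-codebook) or `[B, K, 2T]` (per-timestep). A small self-contained sketch of just that interleaving step, on dummy code tensors:

```python
import torch
from einops import rearrange

B, K, T = 2, 4, 10
codes_left = torch.randint(0, 1024, (B, K, T))    # dummy codes for the left channel
codes_right = torch.randint(0, 1024, (B, K, T))   # dummy codes for the right channel
stacked = torch.stack([codes_left, codes_right], dim=0)           # [2, B, K, T]

per_codebook = rearrange(stacked, 'c b k t -> b (k c) t', c=2)    # [B, 2K, T]
per_timestep = rearrange(stacked, 'c b k t -> b k (t c)', c=2)    # [B, K, 2T]

# undo the interleaving (the deleted get_left_right_codes did exactly this)
left, right = rearrange(per_codebook, 'b (k c) t -> c b k t', c=2)
assert torch.equal(left, codes_left) and torch.equal(right, codes_right)
print(per_codebook.shape, per_timestep.shape)
```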
audiocraft/genmodel.py CHANGED
@@ -6,7 +6,6 @@ import torch
 
 from .encodec import CompressionModel
 from .lm import LMModel
-from .builders import get_wrapped_compression_model
 from .utils.audio_utils import convert_audio
 from .conditioners import ConditioningAttributes
 from .utils.autocast import TorchAutocast
@@ -38,9 +37,6 @@ class BaseGenModel(ABC):
         assert isinstance(cfg, omegaconf.DictConfig)
         self.cfg = cfg
 
-        if self.cfg is not None:
-            self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
-
         if max_duration is None:
             if self.cfg is not None:
                 max_duration = lm.cfg.dataset.segment_duration  # type: ignore
audiocraft/multibanddiffusion.py DELETED
@@ -1,392 +0,0 @@
1
- #====================================== From CompressionSolver.py
2
-
3
- # Copyright (c) Meta Platforms, Inc. and affiliates.
4
- # All rights reserved.
5
- #
6
- # This source code is licensed under the license found in the
7
- # LICENSE file in the root directory of this source tree.
8
-
9
- import logging
10
- import multiprocessing
11
- from pathlib import Path
12
- import typing as tp
13
-
14
- import flashy
15
- import omegaconf
16
- import torch
17
- from torch import nn
18
-
19
- # from . import base, builders
20
- from .. import models, quantization
21
- from ..utils import checkpoint
22
- from ..utils.samples.manager import SampleManager
23
- from ..utils.utils import get_pool_executor
24
-
25
-
26
-
27
-
28
-
29
- class CompressionSolver(): #base.StandardSolver):
30
- """Solver for compression task.
31
-
32
- The compression task combines a set of perceptual and objective losses
33
- to train an EncodecModel (composed of an encoder-decoder and a quantizer)
34
- to perform high fidelity audio reconstruction.
35
- """
36
- def __init__(self, cfg: omegaconf.DictConfig):
37
- # super().__init__(cfg)
38
- self.cfg = cfg
39
- self.rng: torch.Generator # set at each epoch
40
- self.adv_losses = builders.get_adversarial_losses(self.cfg)
41
- self.aux_losses = nn.ModuleDict()
42
- self.info_losses = nn.ModuleDict()
43
- assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
44
- loss_weights = dict()
45
- for loss_name, weight in self.cfg.losses.items():
46
- if loss_name in ['adv', 'feat']:
47
- for adv_name, _ in self.adv_losses.items():
48
- loss_weights[f'{loss_name}_{adv_name}'] = weight
49
- elif weight > 0:
50
- self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
51
- loss_weights[loss_name] = weight
52
- else:
53
- self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
54
- self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
55
- self.register_stateful('adv_losses')
56
-
57
- @property
58
- def best_metric_name(self) -> tp.Optional[str]:
59
- # best model is the last for the compression model
60
- return None
61
-
62
- def build_model(self):
63
- """Instantiate model and optimizer."""
64
- # Model and optimizer
65
- self.model = models.builders.get_compression_model(self.cfg).to(self.device)
66
- self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
67
- self.register_stateful('model', 'optimizer')
68
- self.register_best_state('model')
69
- self.register_ema('model')
70
-
71
-
72
-
73
- def evaluate(self):
74
- """Evaluate stage. Runs audio reconstruction evaluation."""
75
- self.model.eval()
76
- evaluate_stage_name = str(self.current_stage)
77
-
78
- loader = self.dataloaders['evaluate']
79
- updates = len(loader)
80
- lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
81
- average = flashy.averager()
82
-
83
- pendings = []
84
- ctx = multiprocessing.get_context('spawn')
85
- with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
86
- for idx, batch in enumerate(lp):
87
- x = batch.to(self.device)
88
- with torch.no_grad():
89
- qres = self.model(x)
90
-
91
- y_pred = qres.x.cpu()
92
- y = batch.cpu() # should already be on CPU but just in case
93
- pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))
94
-
95
- metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
96
- for pending in metrics_lp:
97
- metrics = pending.result()
98
- metrics = average(metrics)
99
-
100
- metrics = flashy.distrib.average_metrics(metrics, len(loader))
101
- return metrics
102
-
103
- def generate(self):
104
- """Generate stage."""
105
- self.model.eval()
106
- sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
107
- generate_stage_name = str(self.current_stage)
108
-
109
- loader = self.dataloaders['generate']
110
- updates = len(loader)
111
- lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
112
-
113
- for batch in lp:
114
- reference, _ = batch
115
- reference = reference.to(self.device)
116
- with torch.no_grad():
117
- qres = self.model(reference)
118
- assert isinstance(qres, quantization.QuantizedResult)
119
-
120
- reference = reference.cpu()
121
- estimate = qres.x.cpu()
122
- sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
123
-
124
- flashy.distrib.barrier()
125
-
126
- def load_from_pretrained(self, name: str) -> dict:
127
- model = models.CompressionModel.get_pretrained(name)
128
- if isinstance(model, models.DAC):
129
- raise RuntimeError("Cannot fine tune a DAC model.")
130
- elif isinstance(model, models.HFEncodecCompressionModel):
131
- self.logger.warning('Trying to automatically convert a HuggingFace model '
132
- 'to AudioCraft, this might fail!')
133
- state = model.model.state_dict()
134
- new_state = {}
135
- for k, v in state.items():
136
- if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
137
- # We need to determine if this a convtr or a regular conv.
138
- layer = int(k.split('.')[2])
139
- if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
140
-
141
- k = k.replace('.conv.', '.convtr.')
142
- k = k.replace('encoder.layers.', 'encoder.model.')
143
- k = k.replace('decoder.layers.', 'decoder.model.')
144
- k = k.replace('conv.', 'conv.conv.')
145
- k = k.replace('convtr.', 'convtr.convtr.')
146
- k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
147
- k = k.replace('.codebook.', '._codebook.')
148
- new_state[k] = v
149
- state = new_state
150
- elif isinstance(model, models.EncodecModel):
151
- state = model.state_dict()
152
- else:
153
- raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
154
- return {
155
- 'best_state': {'model': state}
156
- }
157
-
158
- @staticmethod
159
- def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
160
- device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
161
- """Instantiate a CompressionModel from a given checkpoint path or dora sig.
162
- This method is a convenient endpoint to load a CompressionModel to use in other solvers.
163
-
164
- Args:
165
- checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
166
- This also supports pre-trained models by using a path of the form //pretrained/NAME.
167
- See `model_from_pretrained` for a list of supported pretrained models.
168
- use_ema (bool): Use EMA variant of the model instead of the actual model.
169
- device (torch.device or str): Device on which the model is loaded.
170
- """
171
- checkpoint_path = str(checkpoint_path)
172
- if checkpoint_path.startswith('//pretrained/'):
173
- name = checkpoint_path.split('/', 3)[-1]
174
- return models.CompressionModel.get_pretrained(name, device)
175
- logger = logging.getLogger(__name__)
176
- logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
177
- _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
178
- assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
179
- state = checkpoint.load_checkpoint(_checkpoint_path)
180
- assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
181
- cfg = state['xp.cfg']
182
- cfg.device = device
183
- compression_model = models.builders.get_compression_model(cfg).to(device)
184
- assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
185
-
186
- assert 'best_state' in state and state['best_state'] != {}
187
- assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
188
- compression_model.load_state_dict(state['best_state']['model'])
189
- compression_model.eval()
190
- logger.info("Compression model loaded!")
191
- return compression_model
192
-
193
- @staticmethod
194
- def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
195
- checkpoint_path: tp.Union[Path, str],
196
- device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
197
- """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.
198
-
199
- Args:
200
- cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
201
- checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
202
- use_ema (bool): Use EMA variant of the model instead of the actual model.
203
- device (torch.device or str): Device on which the model is loaded.
204
- """
205
- compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
206
- compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
207
- return compression_model
208
-
209
-
210
-
211
-
212
-
213
- #=========================================================================== ORIG
214
-
215
- import typing as tp
216
-
217
- import torch
218
- import julius
219
-
220
- from .unet import DiffusionUnet
221
- from ..modules.diffusion_schedule import NoiseSchedule
222
- from .encodec import CompressionModel
223
- from .loaders import load_compression_model, load_diffusion_models
224
-
225
-
226
- class DiffusionProcess:
227
- """Sampling for a diffusion Model.
228
-
229
- Args:
230
- model (DiffusionUnet): Diffusion U-Net model.
231
- noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
232
- """
233
- def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
234
- self.model = model
235
- self.schedule = noise_schedule
236
-
237
- def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
238
- step_list: tp.Optional[tp.List[int]] = None):
239
- """Perform one diffusion process to generate one of the bands.
240
-
241
- Args:
242
- condition (torch.Tensor): The embeddings from the compression model.
243
- initial_noise (torch.Tensor): The initial noise to start the process.
244
- """
245
- return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
246
- condition=condition)
247
-
248
-
249
- class MultiBandDiffusion:
250
- """Sample from multiple diffusion models.
251
-
252
- Args:
253
- DPs (list of DiffusionProcess): Diffusion processes.
254
- codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
255
- """
256
- def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
257
- self.DPs = DPs
258
- self.codec_model = codec_model
259
- self.device = next(self.codec_model.parameters()).device
260
-
261
- @property
262
- def sample_rate(self) -> int:
263
- return self.codec_model.sample_rate
264
-
265
- @staticmethod
266
- def get_mbd_musicgen(device=None):
267
- """Load our diffusion models trained for MusicGen."""
268
- if device is None:
269
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
270
- path = 'facebook/multiband-diffusion'
271
- filename = 'mbd_musicgen_32khz.th'
272
- name = 'facebook/musicgen-small'
273
- codec_model = load_compression_model(name, device=device)
274
- models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
275
- DPs = []
276
- for i in range(len(models)):
277
- schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
278
- DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
279
- return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
280
-
281
- @staticmethod
282
- def get_mbd_24khz(bw: float = 3.0,
283
- device: tp.Optional[tp.Union[torch.device, str]] = None,
284
- n_q: tp.Optional[int] = None):
285
- """Get the pretrained Models for MultibandDiffusion.
286
-
287
- Args:
288
- bw (float): Bandwidth of the compression model.
289
- device (torch.device or str, optional): Device on which the models are loaded.
290
- n_q (int, optional): Number of quantizers to use within the compression model.
291
- """
292
- if device is None:
293
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
294
- assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
295
- if n_q is not None:
296
- assert n_q in [2, 4, 8]
297
- assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
298
- f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
299
- n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
300
- codec_model = CompressionSolver.model_from_checkpoint(
301
- '//pretrained/facebook/encodec_24khz', device=device)
302
- codec_model.set_num_codebooks(n_q)
303
- codec_model = codec_model.to(device)
304
- path = 'facebook/multiband-diffusion'
305
- filename = f'mbd_comp_{n_q}.pt'
306
- models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
307
- DPs = []
308
- for i in range(len(models)):
309
- schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
310
- DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
311
- return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
312
-
313
- @torch.no_grad()
314
- def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
315
- """Get the conditioning (i.e. latent representations of the compression model) from a waveform.
316
- Args:
317
- wav (torch.Tensor): The audio that we want to extract the conditioning from.
318
- sample_rate (int): Sample rate of the audio."""
319
- if sample_rate != self.sample_rate:
320
- wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
321
- codes, scale = self.codec_model.encode(wav)
322
- assert scale is None, "Scaled compression models not supported."
323
- emb = self.get_emb(codes)
324
- return emb
325
-
326
- @torch.no_grad()
327
- def get_emb(self, codes: torch.Tensor):
328
- """Get latent representation from the discrete codes.
329
- Args:
330
- codes (torch.Tensor): Discrete tokens."""
331
- emb = self.codec_model.decode_latent(codes)
332
- return emb
333
-
334
- def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
335
- step_list: tp.Optional[tp.List[int]] = None):
336
- """Generate waveform audio from the latent embeddings of the compression model.
337
- Args:
338
- emb (torch.Tensor): Conditioning embeddings
339
- size (None, torch.Size): Size of the output
340
- if None this is computed from the typical upsampling of the model.
341
- step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step.
342
- """
343
- if size is None:
344
- upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
345
- size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
346
- assert size[0] == emb.size(0)
347
- out = torch.zeros(size).to(self.device)
348
- for DP in self.DPs:
349
- out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
350
- return out
351
-
352
- def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
353
- """Match the eq to the encodec output by matching the standard deviation of some frequency bands.
354
- Args:
355
- wav (torch.Tensor): Audio to equalize.
356
- ref (torch.Tensor): Reference audio from which we match the spectrogram.
357
- n_bands (int): Number of bands of the eq.
358
- strictness (float): How strict the matching. 0 is no matching, 1 is exact matching.
359
- """
360
- split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
361
- bands = split(wav)
362
- bands_ref = split(ref)
363
- out = torch.zeros_like(ref)
364
- for i in range(n_bands):
365
- out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
366
- return out
367
-
368
- def regenerate(self, wav: torch.Tensor, sample_rate: int):
369
- """Regenerate a waveform through compression and diffusion regeneration.
370
- Args:
371
- wav (torch.Tensor): Original 'ground truth' audio.
372
- sample_rate (int): Sample rate of the input (and output) wav.
373
- """
374
- if sample_rate != self.codec_model.sample_rate:
375
- wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
376
- emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
377
- size = wav.size()
378
- out = self.generate(emb, size=size)
379
- if sample_rate != self.codec_model.sample_rate:
380
- out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
381
- return out
382
-
383
- def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
384
- """Generate Waveform audio with diffusion from the discrete codes.
385
- Args:
386
- tokens (torch.Tensor): Discrete codes.
387
- n_bands (int): Bands for the eq matching.
388
- """
389
- wav_encodec = self.codec_model.decode(tokens)
390
- condition = self.get_emb(tokens)
391
- wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
392
- return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
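
One piece of the deleted file worth noting is `re_eq`, which matched the diffusion output's spectrum to the EnCodec output band by band. A minimal re-implementation sketch of that idea follows; it assumes the `julius` package (which the deleted file also imported) and adds a small epsilon that the original did not have.

```python
import julius
import torch

def match_eq(wav: torch.Tensor, ref: torch.Tensor, sample_rate: int = 24000,
             n_bands: int = 32, strictness: float = 1.0) -> torch.Tensor:
    """Rescale each frequency band of `wav` so its std matches the same band of `ref`."""
    split = julius.SplitBands(sample_rate=sample_rate, n_bands=n_bands).to(wav.device)
    bands, bands_ref = split(wav), split(ref)
    out = torch.zeros_like(ref)
    for b, b_ref in zip(bands, bands_ref):
        # strictness 0 = no matching, 1 = match the reference band energy exactly
        out += b * (b_ref.std() / (b.std() + 1e-8)) ** strictness
    return out

# e.g. match a diffusion-decoded waveform to the EnCodec-decoded reference:
# wav_matched = match_eq(wav_diffusion, wav_encodec, sample_rate=24000)
```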
demo.py CHANGED
@@ -12,4 +12,4 @@ sound_generator.set_generation_params(duration=1)  # why is generating so long
 x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
 x /= np.abs(x).max() + 1e-7
 
-audiofile.write('_audio_.wav', x, 16000)
+audiofile.write('_audio3_.wav', x, 16000)
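The lines touched here are the tail of a short script: demo.py peak-normalises the generated waveform and writes it as a 16 kHz WAV. The sketch below reproduces just that post-processing on a placeholder signal; loading the sound generator itself is not shown.

```python
import audiofile
import numpy as np

# placeholder for sound_generator.generate([txt])[0] ... (1 s of noise at 16 kHz)
x = np.random.randn(16000).astype(np.float32)
x /= np.abs(x).max() + 1e-7                  # peak normalisation, as in demo.py
audiofile.write('_audio3_.wav', x, 16000)    # filename matches the updated script
```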