cleanup 3

- README.md +19 -6
- audiocraft/builders.py +1 -18
- audiocraft/conditioners.py +2 -103
- audiocraft/encodec.py +1 -248
- audiocraft/genmodel.py +0 -4
- audiocraft/multibanddiffusion.py +0 -392
- demo.py +1 -1
README.md
CHANGED

@@ -2,20 +2,17 @@
 license: mit
 language:
 - en
-pipeline_tag:
+pipeline_tag: audio-generation
 tags:
 - audiocraft
 - audiogen
 - styletts2
-- audio
-- synthesis
 - shift
 - audeering
-- dkounadis
 - sound
-- scene
-- acoustic-scene
 - audio-generation
+- text-to-speech
+- mimic3
 ---
 
 
@@ -35,13 +32,23 @@ tags:
 
 ```
 git clone https://huggingface.co/dkounadis/artificial-styletts2
+```
+
+<details>
+<summary>
+Create virtualenv
+</summary>
 
+```
 virtualenv --python=python3 ~/.envs/.my_env
 source ~/.envs/.my_env/bin/activate
 cd artificial-styletts2/
 pip install -r requirements.txt
 ```
 
+
+</details>
+
 Start Flask
 
 ```
@@ -128,4 +135,10 @@ Client - Describe any sound with words and it will be played back to you.
 
 ```python
 python live_demo.py # will ask text input & play soundscape
+```
+
+# Simple Demo
+
+```python
+CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=/data/dkounadis/.hf7/ CUDA_VISIBLE_DEVICES=4 python demo.py
 ```

audiocraft/builders.py
CHANGED

@@ -15,7 +15,7 @@ import audiocraft
 import omegaconf
 import torch
 
-from .encodec import CompressionModel, EncodecModel
+from .encodec import CompressionModel, EncodecModel
 from .lm import LMModel
 from .seanet import SEANetEncoder, SEANetDecoder
 from .codebooks_patterns import (
@@ -211,20 +211,3 @@ def get_processor(cfg, sample_rate: int = 24000):
         if cfg.name == "multi_band_processor":
             sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
     return sample_processor
-
-
-
-
-
-def get_wrapped_compression_model(
-        compression_model: CompressionModel,
-        cfg: omegaconf.DictConfig) -> CompressionModel:
-    if hasattr(cfg, 'interleave_stereo_codebooks'):
-        if cfg.interleave_stereo_codebooks.use:
-            kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
-            kwargs.pop('use')
-            compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
-    if hasattr(cfg, 'compression_model_n_q'):
-        if cfg.compression_model_n_q is not None:
-            compression_model.set_num_codebooks(cfg.compression_model_n_q)
-    return compression_model

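The removed `get_wrapped_compression_model` was the only place where stereo interleaving and `compression_model_n_q` were honoured. A minimal sketch of the one behaviour a caller might still want to reproduce (capping the active codebooks), assuming a config that may carry the optional `compression_model_n_q` field from the deleted code; the helper name `restrict_codebooks` is hypothetical, not part of the repo:

```python
import omegaconf

from .encodec import CompressionModel  # same import the file keeps


def restrict_codebooks(compression_model: CompressionModel,
                       cfg: omegaconf.DictConfig) -> CompressionModel:
    # Sketch only: mirrors the tail of the deleted wrapper. If the config still
    # specifies compression_model_n_q, limit how many RVQ codebooks the
    # quantizer uses at inference time.
    if hasattr(cfg, 'compression_model_n_q') and cfg.compression_model_n_q is not None:
        compression_model.set_num_codebooks(cfg.compression_model_n_q)
    return compression_model
```
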
audiocraft/conditioners.py
CHANGED

@@ -1,5 +1,4 @@
 from collections import defaultdict
-from copy import deepcopy
 from dataclasses import dataclass, field
 from itertools import chain
 import logging
@@ -10,20 +9,12 @@ import re
 import typing as tp
 import warnings
 import soundfile
-from num2words import num2words
-import spacy
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
-import torch.nn.functional as F
-from torch.nn.utils.rnn import pad_sequence
 from .streaming import StreamingModule
 
 
-from .streaming import StreamingModule
-from .transformer import create_sin_embedding
-
-
 from .quantization import ResidualVectorQuantizer
 from .utils.autocast import TorchAutocast
 from .utils.cache import EmbeddingCache
@@ -112,102 +103,10 @@ class Tokenizer:
         raise NotImplementedError()
 
 
-class WhiteSpaceTokenizer(Tokenizer):
-    """This tokenizer should be used for natural language descriptions.
-    For example:
-        ["he didn't, know he's going home.", 'shorter sentence'] =>
-        [[78, 62, 31, 4, 78, 25, 19, 34],
-        [59, 77, 0, 0, 0, 0, 0, 0]]
-    """
-    PUNCTUATION = "?:!.,;"
-
-    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
-                 lemma: bool = True, stopwords: bool = True) -> None:
-        self.n_bins = n_bins
-        self.pad_idx = pad_idx
-        self.lemma = lemma
-        self.stopwords = stopwords
-        try:
-            self.nlp = spacy.load(language)
-        except IOError:
-            spacy.cli.download(language)  # type: ignore
-            self.nlp = spacy.load(language)
-
-    @tp.no_type_check
-    def __call__(self, texts: tp.List[tp.Optional[str]],
-                 return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        """Take a list of strings and convert them to a tensor of indices.
 
-        Args:
-            texts (list[str]): List of strings.
-            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
-        Returns:
-            tuple[torch.Tensor, torch.Tensor]:
-                - Indices of words in the LUT.
-                - And a mask indicating where the padding tokens are
-        """
-        output, lengths = [], []
-        texts = deepcopy(texts)
-        for i, text in enumerate(texts):
-            # if current sample doesn't have a certain attribute, replace with pad token
-            if text is None:
-                output.append(torch.Tensor([self.pad_idx]))
-                lengths.append(0)
-                continue
-
-            # convert numbers to words
-            text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text)  # type: ignore
-            # normalize text
-            text = self.nlp(text)  # type: ignore
-            # remove stopwords
-            if self.stopwords:
-                text = [w for w in text if not w.is_stop]  # type: ignore
-            # remove punctuation
-            text = [w for w in text if w.text not in self.PUNCTUATION]  # type: ignore
-            # lemmatize if needed
-            text = [getattr(t, "lemma_" if self.lemma else "text") for t in text]  # type: ignore
-
-            texts[i] = " ".join(text)
-            lengths.append(len(text))
-            # convert to tensor
-            tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
-            output.append(tokens)
-
-        mask = length_to_mask(torch.IntTensor(lengths)).int()
-        padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
-        if return_text:
-            return padded_output, mask, texts  # type: ignore
-        return padded_output, mask
-
-
-class NoopTokenizer(Tokenizer):
-    """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
-    The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
-    strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
-    split it to ["Jeff", "Buckley"] and return an index per word.
-
-    For example:
-        ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
-        ["Metal", "Rock", "Classical"] => [0, 223, 51]
-    """
-    def __init__(self, n_bins: int, pad_idx: int = 0):
-        self.n_bins = n_bins
-        self.pad_idx = pad_idx
 
-
-
-        for text in texts:
-            # if current sample doesn't have a certain attribute, replace with pad token
-            if text is None:
-                output.append(self.pad_idx)
-                lengths.append(0)
-            else:
-                output.append(hash_trick(text, self.n_bins))
-                lengths.append(1)
-
-        tokens = torch.LongTensor(output).unsqueeze(1)
-        mask = length_to_mask(torch.IntTensor(lengths)).int()
-        return tokens, mask
+
+
 
 
 class BaseConditioner(nn.Module):

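With the spaCy/num2words tokenizers gone, the only text path left in this file is the T5 one (the `T5EncoderModel`/`T5Tokenizer` imports that survive above). The snippet below is not the repo's conditioner class, just a rough sketch of how those two objects are typically combined using the standard `transformers` API; the `t5-base` checkpoint name is an assumption for illustration:

```python
import torch
from transformers import T5EncoderModel, T5Tokenizer

# Sketch only: tokenize a text description and embed it with the T5 encoder.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
encoder = T5EncoderModel.from_pretrained("t5-base").eval()

with torch.no_grad():
    inputs = tokenizer(["dogs barking in the rain"], return_tensors="pt", padding=True)
    # [batch, tokens, hidden] embeddings a downstream LM can cross-attend to
    embeddings = encoder(**inputs).last_hidden_state
```
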
audiocraft/encodec.py
CHANGED

@@ -256,251 +256,4 @@ class EncodecModel(CompressionModel):
 
     def decode_latent(self, codes: torch.Tensor):
         """Decode from the discrete codes to continuous latent space."""
-        return self.quantizer.decode(codes)
-
-
-class DAC(CompressionModel):
-    def __init__(self, model_type: str = "44khz"):
-        super().__init__()
-        try:
-            import dac.utils
-        except ImportError:
-            raise RuntimeError("Could not import dac, make sure it is installed, "
-                               "please run `pip install descript-audio-codec`")
-        self.model = dac.utils.load_model(model_type=model_type)
-        self.n_quantizers = self.total_codebooks
-        self.model.eval()
-
-    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
-        # We don't support training with this.
-        raise NotImplementedError("Forward and training with DAC not supported.")
-
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        codes = self.model.encode(x, self.n_quantizers)[1]
-        return codes[:, :self.n_quantizers], None
-
-    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
-        assert scale is None
-        z_q = self.decode_latent(codes)
-        return self.model.decode(z_q)
-
-    def decode_latent(self, codes: torch.Tensor):
-        """Decode from the discrete codes to continuous latent space."""
-        return self.model.quantizer.from_codes(codes)[0]
-
-    @property
-    def channels(self) -> int:
-        return 1
-
-    @property
-    def frame_rate(self) -> float:
-        return self.model.sample_rate / self.model.hop_length
-
-    @property
-    def sample_rate(self) -> int:
-        return self.model.sample_rate
-
-    @property
-    def cardinality(self) -> int:
-        return self.model.codebook_size
-
-    @property
-    def num_codebooks(self) -> int:
-        return self.n_quantizers
-
-    @property
-    def total_codebooks(self) -> int:
-        return self.model.n_codebooks
-
-    def set_num_codebooks(self, n: int):
-        """Set the active number of codebooks used by the quantizer.
-        """
-        assert n >= 1
-        assert n <= self.total_codebooks
-        self.n_quantizers = n
-
-
-class HFEncodecCompressionModel(CompressionModel):
-    """Wrapper around HuggingFace Encodec.
-    """
-    def __init__(self, model: HFEncodecModel):
-        super().__init__()
-        self.model = model
-        bws = self.model.config.target_bandwidths
-        num_codebooks = [
-            bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
-            for bw in bws
-        ]
-        deltas = [nc - int(nc) for nc in num_codebooks]
-        # Checking we didn't do some bad maths and we indeed have integers!
-        assert all(deltas) <= 1e-3, deltas
-        self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
-        self.set_num_codebooks(max(self.possible_num_codebooks))
-
-    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
-        # We don't support training with this.
-        raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
-
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
-        bandwidth = self.model.config.target_bandwidths[bandwidth_index]
-        res = self.model.encode(x, None, bandwidth)
-        assert len(res[0]) == 1
-        assert len(res[1]) == 1
-        return res[0][0], res[1][0]
-
-    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
-        if scale is None:
-            scales = [None]  # type: ignore
-        else:
-            scales = scale  # type: ignore
-        res = self.model.decode(codes[None], scales)
-        return res[0]
-
-    def decode_latent(self, codes: torch.Tensor):
-        """Decode from the discrete codes to continuous latent space."""
-        return self.model.quantizer.decode(codes.transpose(0, 1))
-
-    @property
-    def channels(self) -> int:
-        return self.model.config.audio_channels
-
-    @property
-    def frame_rate(self) -> float:
-        hop_length = int(np.prod(self.model.config.upsampling_ratios))
-        return self.sample_rate / hop_length
-
-    @property
-    def sample_rate(self) -> int:
-        return self.model.config.sampling_rate
-
-    @property
-    def cardinality(self) -> int:
-        return self.model.config.codebook_size
-
-    @property
-    def num_codebooks(self) -> int:
-        return self._num_codebooks
-
-    @property
-    def total_codebooks(self) -> int:
-        return max(self.possible_num_codebooks)
-
-    def set_num_codebooks(self, n: int):
-        """Set the active number of codebooks used by the quantizer.
-        """
-        if n not in self.possible_num_codebooks:
-            raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
-        self._num_codebooks = n
-
-
-class InterleaveStereoCompressionModel(CompressionModel):
-    """Wraps a CompressionModel to support stereo inputs. The wrapped model
-    will be applied independently to the left and right channels, and both codebooks
-    will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per
-    channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
-    `per_timestep`.
-
-    Args:
-        model (CompressionModel): Compression model to wrap.
-        per_timestep (bool): Whether to interleave on the timestep dimension
-            or on the codebooks dimension.
-    """
-    def __init__(self, model: CompressionModel, per_timestep: bool = False):
-        super().__init__()
-        self.model = model
-        self.per_timestep = per_timestep
-        assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"
-
-    @property
-    def total_codebooks(self):
-        return self.model.total_codebooks
-
-    @property
-    def num_codebooks(self):
-        """Active number of codebooks used by the quantizer.
-
-        ..Warning:: this reports the number of codebooks after the interleaving
-        of the codebooks!
-        """
-        return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2
-
-    def set_num_codebooks(self, n: int):
-        """Set the active number of codebooks used by the quantizer.
-
-        ..Warning:: this sets the number of codebooks before the interleaving!
-        """
-        self.model.set_num_codebooks(n)
-
-    @property
-    def num_virtual_steps(self) -> float:
-        """Return the number of virtual steps, e.g. one real step
-        will be split into that many steps.
-        """
-        return 2 if self.per_timestep else 1
-
-    @property
-    def frame_rate(self) -> float:
-        return self.model.frame_rate * self.num_virtual_steps
-
-    @property
-    def sample_rate(self) -> int:
-        return self.model.sample_rate
-
-    @property
-    def channels(self) -> int:
-        return 2
-
-    @property
-    def cardinality(self):
-        """Cardinality of each codebook.
-        """
-        return self.model.cardinality
-
-    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
-        raise NotImplementedError("Not supported, use encode and decode.")
-
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        B, C, T = x.shape
-        assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"
-
-        indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
-        indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
-        indices = torch.stack([indices_c0, indices_c1], dim=0)
-        scales: tp.Optional[torch.Tensor] = None
-        if scales_c0 is not None and scales_c1 is not None:
-            scales = torch.stack([scales_c0, scales_c1], dim=1)
-
-        if self.per_timestep:
-            indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
-        else:
-            indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)
-
-        return (indices, scales)
-
-    def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        if self.per_timestep:
-            codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
-        else:
-            codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
-        return codes[0], codes[1]
-
-    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
-        B, K, T = codes.shape
-        assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
-        assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"
-
-        scale_c0, scale_c1 = None, None
-        if scale is not None:
-            assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
-            scale_c0 = scale[0, ...]
-            scale_c1 = scale[1, ...]
-
-        codes_c0, codes_c1 = self.get_left_right_codes(codes)
-        audio_c0 = self.model.decode(codes_c0, scale_c0)
-        audio_c1 = self.model.decode(codes_c1, scale_c1)
-        return torch.cat([audio_c0, audio_c1], dim=1)
-
-    def decode_latent(self, codes: torch.Tensor):
-        """Decode from the discrete codes to continuous latent space."""
-        raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
+        return self.quantizer.decode(codes)

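Only the plain `EncodecModel` remains after this file's cleanup. Below is a rough sketch of the round trip the kept `decode_latent` belongs to; `compression_model` is assumed to be an already-loaded `EncodecModel` (loading is not shown in this hunk), and the `roundtrip` helper name is hypothetical:

```python
import torch


def roundtrip(compression_model, wav: torch.Tensor):
    """Sketch only: encode/decode path of the surviving EncodecModel API."""
    codes, scale = compression_model.encode(wav)          # discrete RVQ tokens [B, K, T_frames]
    latent = compression_model.decode_latent(codes)       # continuous latents via quantizer.decode (kept above)
    audio = compression_model.decode(codes, scale)        # waveform reconstructed from the tokens
    return codes, latent, audio


# e.g. one second of silent mono audio at the codec's sample rate
# roundtrip(compression_model, torch.zeros(1, 1, compression_model.sample_rate))
```
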
audiocraft/genmodel.py
CHANGED

@@ -6,7 +6,6 @@ import torch
 
 from .encodec import CompressionModel
 from .lm import LMModel
-from .builders import get_wrapped_compression_model
 from .utils.audio_utils import convert_audio
 from .conditioners import ConditioningAttributes
 from .utils.autocast import TorchAutocast
@@ -38,9 +37,6 @@ class BaseGenModel(ABC):
         assert isinstance(cfg, omegaconf.DictConfig)
         self.cfg = cfg
 
-        if self.cfg is not None:
-            self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
-
         if max_duration is None:
             if self.cfg is not None:
                 max_duration = lm.cfg.dataset.segment_duration  # type: ignore

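Without the wrapper call, `BaseGenModel` now uses the compression model exactly as it was passed in; the only logic kept around it is the duration fallback. A sketch of that fallback as a free function (hypothetical helper, not part of the repo):

```python
def resolve_max_duration(max_duration, cfg, lm):
    # Sketch only: if no explicit duration is given and a config is available,
    # fall back to the segment length the LM was trained on, as in the kept lines.
    if max_duration is None and cfg is not None:
        max_duration = lm.cfg.dataset.segment_duration
    return max_duration
```
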
audiocraft/multibanddiffusion.py
DELETED

@@ -1,392 +0,0 @@
-#====================================== From CompressionSolver.py
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-import multiprocessing
-from pathlib import Path
-import typing as tp
-
-import flashy
-import omegaconf
-import torch
-from torch import nn
-
-# from . import base, builders
-from .. import models, quantization
-from ..utils import checkpoint
-from ..utils.samples.manager import SampleManager
-from ..utils.utils import get_pool_executor
-
-
-
-
-
-class CompressionSolver(): #base.StandardSolver):
-    """Solver for compression task.
-
-    The compression task combines a set of perceptual and objective losses
-    to train an EncodecModel (composed of an encoder-decoder and a quantizer)
-    to perform high fidelity audio reconstruction.
-    """
-    def __init__(self, cfg: omegaconf.DictConfig):
-        # super().__init__(cfg)
-        self.cfg = cfg
-        self.rng: torch.Generator  # set at each epoch
-        self.adv_losses = builders.get_adversarial_losses(self.cfg)
-        self.aux_losses = nn.ModuleDict()
-        self.info_losses = nn.ModuleDict()
-        assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
-        loss_weights = dict()
-        for loss_name, weight in self.cfg.losses.items():
-            if loss_name in ['adv', 'feat']:
-                for adv_name, _ in self.adv_losses.items():
-                    loss_weights[f'{loss_name}_{adv_name}'] = weight
-            elif weight > 0:
-                self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
-                loss_weights[loss_name] = weight
-            else:
-                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
-        self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
-        self.register_stateful('adv_losses')
-
-    @property
-    def best_metric_name(self) -> tp.Optional[str]:
-        # best model is the last for the compression model
-        return None
-
-    def build_model(self):
-        """Instantiate model and optimizer."""
-        # Model and optimizer
-        self.model = models.builders.get_compression_model(self.cfg).to(self.device)
-        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
-        self.register_stateful('model', 'optimizer')
-        self.register_best_state('model')
-        self.register_ema('model')
-
-
-
-    def evaluate(self):
-        """Evaluate stage. Runs audio reconstruction evaluation."""
-        self.model.eval()
-        evaluate_stage_name = str(self.current_stage)
-
-        loader = self.dataloaders['evaluate']
-        updates = len(loader)
-        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
-        average = flashy.averager()
-
-        pendings = []
-        ctx = multiprocessing.get_context('spawn')
-        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
-            for idx, batch in enumerate(lp):
-                x = batch.to(self.device)
-                with torch.no_grad():
-                    qres = self.model(x)
-
-                y_pred = qres.x.cpu()
-                y = batch.cpu()  # should already be on CPU but just in case
-                pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))
-
-            metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
-            for pending in metrics_lp:
-                metrics = pending.result()
-                metrics = average(metrics)
-
-        metrics = flashy.distrib.average_metrics(metrics, len(loader))
-        return metrics
-
-    def generate(self):
-        """Generate stage."""
-        self.model.eval()
-        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
-        generate_stage_name = str(self.current_stage)
-
-        loader = self.dataloaders['generate']
-        updates = len(loader)
-        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
-
-        for batch in lp:
-            reference, _ = batch
-            reference = reference.to(self.device)
-            with torch.no_grad():
-                qres = self.model(reference)
-            assert isinstance(qres, quantization.QuantizedResult)
-
-            reference = reference.cpu()
-            estimate = qres.x.cpu()
-            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
-
-        flashy.distrib.barrier()
-
-    def load_from_pretrained(self, name: str) -> dict:
-        model = models.CompressionModel.get_pretrained(name)
-        if isinstance(model, models.DAC):
-            raise RuntimeError("Cannot fine tune a DAC model.")
-        elif isinstance(model, models.HFEncodecCompressionModel):
-            self.logger.warning('Trying to automatically convert a HuggingFace model '
-                                'to AudioCraft, this might fail!')
-            state = model.model.state_dict()
-            new_state = {}
-            for k, v in state.items():
-                if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
-                    # We need to determine if this a convtr or a regular conv.
-                    layer = int(k.split('.')[2])
-                    if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
-
-                        k = k.replace('.conv.', '.convtr.')
-                k = k.replace('encoder.layers.', 'encoder.model.')
-                k = k.replace('decoder.layers.', 'decoder.model.')
-                k = k.replace('conv.', 'conv.conv.')
-                k = k.replace('convtr.', 'convtr.convtr.')
-                k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
-                k = k.replace('.codebook.', '._codebook.')
-                new_state[k] = v
-            state = new_state
-        elif isinstance(model, models.EncodecModel):
-            state = model.state_dict()
-        else:
-            raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
-        return {
-            'best_state': {'model': state}
-        }
-
-    @staticmethod
-    def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
-                              device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
-        """Instantiate a CompressionModel from a given checkpoint path or dora sig.
-        This method is a convenient endpoint to load a CompressionModel to use in other solvers.
-
-        Args:
-            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
-                This also supports pre-trained models by using a path of the form //pretrained/NAME.
-                See `model_from_pretrained` for a list of supported pretrained models.
-            use_ema (bool): Use EMA variant of the model instead of the actual model.
-            device (torch.device or str): Device on which the model is loaded.
-        """
-        checkpoint_path = str(checkpoint_path)
-        if checkpoint_path.startswith('//pretrained/'):
-            name = checkpoint_path.split('/', 3)[-1]
-            return models.CompressionModel.get_pretrained(name, device)
-        logger = logging.getLogger(__name__)
-        logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
-        _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
-        assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
-        state = checkpoint.load_checkpoint(_checkpoint_path)
-        assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
-        cfg = state['xp.cfg']
-        cfg.device = device
-        compression_model = models.builders.get_compression_model(cfg).to(device)
-        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
-
-        assert 'best_state' in state and state['best_state'] != {}
-        assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
-        compression_model.load_state_dict(state['best_state']['model'])
-        compression_model.eval()
-        logger.info("Compression model loaded!")
-        return compression_model
-
-    @staticmethod
-    def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
-                                      checkpoint_path: tp.Union[Path, str],
-                                      device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
-        """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.
-
-        Args:
-            cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
-            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
-            use_ema (bool): Use EMA variant of the model instead of the actual model.
-            device (torch.device or str): Device on which the model is loaded.
-        """
-        compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
-        compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
-        return compression_model
-
-
-
-
-
-#=========================================================================== ORIG
-
-import typing as tp
-
-import torch
-import julius
-
-from .unet import DiffusionUnet
-from ..modules.diffusion_schedule import NoiseSchedule
-from .encodec import CompressionModel
-from .loaders import load_compression_model, load_diffusion_models
-
-
-class DiffusionProcess:
-    """Sampling for a diffusion Model.
-
-    Args:
-        model (DiffusionUnet): Diffusion U-Net model.
-        noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
-    """
-    def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
-        self.model = model
-        self.schedule = noise_schedule
-
-    def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
-                 step_list: tp.Optional[tp.List[int]] = None):
-        """Perform one diffusion process to generate one of the bands.
-
-        Args:
-            condition (torch.Tensor): The embeddings from the compression model.
-            initial_noise (torch.Tensor): The initial noise to start the process.
-        """
-        return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
-                                                 condition=condition)
-
-
-class MultiBandDiffusion:
-    """Sample from multiple diffusion models.
-
-    Args:
-        DPs (list of DiffusionProcess): Diffusion processes.
-        codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
-    """
-    def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
-        self.DPs = DPs
-        self.codec_model = codec_model
-        self.device = next(self.codec_model.parameters()).device
-
-    @property
-    def sample_rate(self) -> int:
-        return self.codec_model.sample_rate
-
-    @staticmethod
-    def get_mbd_musicgen(device=None):
-        """Load our diffusion models trained for MusicGen."""
-        if device is None:
-            device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        path = 'facebook/multiband-diffusion'
-        filename = 'mbd_musicgen_32khz.th'
-        name = 'facebook/musicgen-small'
-        codec_model = load_compression_model(name, device=device)
-        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
-        DPs = []
-        for i in range(len(models)):
-            schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
-            DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
-        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
-
-    @staticmethod
-    def get_mbd_24khz(bw: float = 3.0,
-                      device: tp.Optional[tp.Union[torch.device, str]] = None,
-                      n_q: tp.Optional[int] = None):
-        """Get the pretrained Models for MultibandDiffusion.
-
-        Args:
-            bw (float): Bandwidth of the compression model.
-            device (torch.device or str, optional): Device on which the models are loaded.
-            n_q (int, optional): Number of quantizers to use within the compression model.
-        """
-        if device is None:
-            device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
-        if n_q is not None:
-            assert n_q in [2, 4, 8]
-            assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
-                f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
-        n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
-        codec_model = CompressionSolver.model_from_checkpoint(
-            '//pretrained/facebook/encodec_24khz', device=device)
-        codec_model.set_num_codebooks(n_q)
-        codec_model = codec_model.to(device)
-        path = 'facebook/multiband-diffusion'
-        filename = f'mbd_comp_{n_q}.pt'
-        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
-        DPs = []
-        for i in range(len(models)):
-            schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
-            DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
-        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
-
-    @torch.no_grad()
-    def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
-        """Get the conditioning (i.e. latent representations of the compression model) from a waveform.
-        Args:
-            wav (torch.Tensor): The audio that we want to extract the conditioning from.
-            sample_rate (int): Sample rate of the audio."""
-        if sample_rate != self.sample_rate:
-            wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
-        codes, scale = self.codec_model.encode(wav)
-        assert scale is None, "Scaled compression models not supported."
-        emb = self.get_emb(codes)
-        return emb
-
-    @torch.no_grad()
-    def get_emb(self, codes: torch.Tensor):
-        """Get latent representation from the discrete codes.
-        Args:
-            codes (torch.Tensor): Discrete tokens."""
-        emb = self.codec_model.decode_latent(codes)
-        return emb
-
-    def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
-                 step_list: tp.Optional[tp.List[int]] = None):
-        """Generate waveform audio from the latent embeddings of the compression model.
-        Args:
-            emb (torch.Tensor): Conditioning embeddings
-            size (None, torch.Size): Size of the output
-                if None this is computed from the typical upsampling of the model.
-            step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step.
-        """
-        if size is None:
-            upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
-            size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
-        assert size[0] == emb.size(0)
-        out = torch.zeros(size).to(self.device)
-        for DP in self.DPs:
-            out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
-        return out
-
-    def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
-        """Match the eq to the encodec output by matching the standard deviation of some frequency bands.
-        Args:
-            wav (torch.Tensor): Audio to equalize.
-            ref (torch.Tensor): Reference audio from which we match the spectrogram.
-            n_bands (int): Number of bands of the eq.
-            strictness (float): How strict the matching. 0 is no matching, 1 is exact matching.
-        """
-        split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
-        bands = split(wav)
-        bands_ref = split(ref)
-        out = torch.zeros_like(ref)
-        for i in range(n_bands):
-            out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
-        return out
-
-    def regenerate(self, wav: torch.Tensor, sample_rate: int):
-        """Regenerate a waveform through compression and diffusion regeneration.
-        Args:
-            wav (torch.Tensor): Original 'ground truth' audio.
-            sample_rate (int): Sample rate of the input (and output) wav.
-        """
-        if sample_rate != self.codec_model.sample_rate:
-            wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
-        emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
-        size = wav.size()
-        out = self.generate(emb, size=size)
-        if sample_rate != self.codec_model.sample_rate:
-            out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
-        return out
-
-    def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
-        """Generate Waveform audio with diffusion from the discrete codes.
-        Args:
-            tokens (torch.Tensor): Discrete codes.
-            n_bands (int): Bands for the eq matching.
-        """
-        wav_encodec = self.codec_model.decode(tokens)
-        condition = self.get_emb(tokens)
-        wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
-        return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)

demo.py
CHANGED

@@ -12,4 +12,4 @@ sound_generator.set_generation_params(duration=1)  # why is generating so long
 x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
 x /= np.abs(x).max() + 1e-7
 
-audiofile.write('
+audiofile.write('_audio3_.wav', x, 16000)
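For reference, the lines touched here are the tail of a very small script: generate audio from a text prompt, peak-normalise it, and write it to disk. A sketch of the same steps wrapped into a function; `sound_generator` is assumed to be the AudioGen-style model the repo builds earlier in demo.py (its construction is not part of this hunk), and `synthesize` is a hypothetical name:

```python
import numpy as np
import audiofile


def synthesize(sound_generator, txt: str, out_path: str = '_audio3_.wav') -> str:
    # Sketch only: mirrors the post-processing shown in the diff.
    x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
    x /= np.abs(x).max() + 1e-7           # peak-normalise to avoid clipping
    audiofile.write(out_path, x, 16000)   # 16 kHz mono wav, as in the fixed line
    return out_path
```
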