clean unused functions
Files changed:
- audiocraft/builders.py +5 -41
- audiocraft/codebooks_patterns.py +1 -190
- audiocraft/conditioners.py +4 -1
- audiocraft/encodec.py +5 -47
- audiocraft/lm.py +41 -859
- audiocraft/loaders.py +2 -1
- demo.py +4 -5
audiocraft/builders.py
CHANGED
@@ -10,22 +10,13 @@ from the Hydra config.
 """
 
 import typing as tp
-
-import audiocraft
 import omegaconf
 import torch
 
 from .encodec import CompressionModel, EncodecModel
 from .lm import LMModel
 from .seanet import SEANetDecoder
-from .codebooks_patterns import (
-    CodebooksPatternProvider,
-    DelayedPatternProvider,
-    MusicLMPattern,
-    ParallelPatternProvider,
-    UnrolledPatternProvider,
-    CoarseFirstPattern,
-)
+from .codebooks_patterns import DelayedPatternProvider
 from .conditioners import (
     BaseConditioner,
     ConditionFuser,
@@ -159,45 +150,18 @@ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
     return fuser
 
 
-def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
-    """Instantiate a codebooks pattern provider object."""
+def get_codebooks_pattern_provider(n_q, cfg):
     pattern_providers = {
-        'parallel': ParallelPatternProvider,
-        'delay': DelayedPatternProvider,
-        'unroll': UnrolledPatternProvider,
-        'coarse_first': CoarseFirstPattern,
-        'musiclm': MusicLMPattern,
+        'delay': DelayedPatternProvider,  # THIS
     }
     name = cfg.modeling
    kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
+
     klass = pattern_providers[name]
     return klass(n_q, **kwargs)
 
 
-def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
-    """Instantiate a debug compression model to be used for unit tests."""
-    assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
-    model_ratios = {
-        16000: [10, 8, 8],   # 25 Hz at 16kHz
-        32000: [10, 8, 16]   # 25 Hz at 32kHz
-    }
-    ratios: tp.List[int] = model_ratios[sample_rate]
-    frame_rate = 25
-    seanet_kwargs: dict = {
-        'n_filters': 4,
-        'n_residual_layers': 1,
-        'dimension': 32,
-        'ratios': ratios,
-    }
-    encoder = SEANetEncoder(**seanet_kwargs)
-    decoder = SEANetDecoder(**seanet_kwargs)
-    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
-    init_x = torch.randn(8, 32, 128)
-    quantizer(init_x, 1)  # initialize kmeans etc.
-    compression_model = EncodecModel(
-        encoder, decoder, quantizer,
-        frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
-    return compression_model.eval()
+
 
 
 def get_diffusion_model(cfg: omegaconf.DictConfig):
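For reference, a minimal sketch (not part of the commit) of how the trimmed get_codebooks_pattern_provider would be driven; the 'modeling: delay' layout and the 'delay.delays' key are assumptions based on the usual MusicGen-style config, not taken from this repository:

# Sketch only: assumes the simplified builder above and a Hydra-style config
# that sets modeling to 'delay' with an optional 'delay' sub-config.
import omegaconf
from audiocraft.builders import get_codebooks_pattern_provider

cfg = omegaconf.OmegaConf.create({'modeling': 'delay', 'delay': {'delays': [0, 1, 2, 3]}})
provider = get_codebooks_pattern_provider(n_q=4, cfg=cfg)  # -> DelayedPatternProvider(4, delays=[0, 1, 2, 3])
pattern = provider.get_pattern(timesteps=8)                # Pattern over 4 codebooks and 8 timesteps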
audiocraft/codebooks_patterns.py
CHANGED
@@ -52,7 +52,7 @@ class Pattern:
         self._validate_layout()
         self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
         self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
-        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
+        print("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
 
     def _validate_layout(self):
         """Runs checks on the layout to ensure a valid pattern is defined.
@@ -356,193 +356,4 @@ class DelayedPatternProvider(CodebooksPatternProvider):
         return Pattern(out, n_q=self.n_q, timesteps=timesteps)
 
 
-class ParallelPatternProvider(DelayedPatternProvider):
-    """Provider for parallel pattern across codebooks.
-    This pattern provider is a special case of the delayed pattern with actually no delay,
-    hence delays=repeat(0, n_q).
-
-    Args:
-        n_q (int): Number of codebooks.
-        empty_initial (int): Prepend with N empty list of coordinates.
-    """
-    def __init__(self, n_q: int, empty_initial: int = 0):
-        super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
-
-
-class UnrolledPatternProvider(CodebooksPatternProvider):
-    """Provider for unrolling codebooks pattern.
-    This pattern provider enables to represent the codebook flattened completely or only to some extend
-    while also specifying a given delay between the flattened codebooks representation, allowing to
-    unroll the codebooks in the sequence.
-
-    Example:
-        1. Flattening of the codebooks.
-        By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
-        taking n_q = 3 and timesteps = 4:
-        [[1, 2, 3, 4],
-         [1, 2, 3, 4],
-         [1, 2, 3, 4]]
-        will result into:
-        [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
-         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
-         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
-        2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
-        for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
-        taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
-        [[1, 2, 3, 4],
-         [1, 2, 3, 4],
-         [1, 2, 3, 4]]
-        will result into:
-        [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
-         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
-         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
-        3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
-        allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
-        same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
-        and delays = [0, 3, 3]:
-        [[1, 2, 3, 4],
-         [1, 2, 3, 4],
-         [1, 2, 3, 4]]
-        will result into:
-        [[S, S, S, 1, S, 2, S, 3, S, 4],
-         [S, S, S, 1, S, 2, S, 3, S, 4],
-         [1, 2, 3, S, 4, S, 5, S, 6, S]]
-
-    Args:
-        n_q (int): Number of codebooks.
-        flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
-            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
-            have n_q extra steps for each timestep.
-        delays (list of int, optional): Delay for each of the codebooks. If not defined,
-            no delay is added and therefore will default to [0] * ``n_q``.
-            Note that two codebooks that will be flattened to the same inner step
-            should have the same delay, otherwise the pattern is considered as invalid.
-    """
-    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
-
-    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
-                 delays: tp.Optional[tp.List[int]] = None):
-        super().__init__(n_q)
-        if flattening is None:
-            flattening = list(range(n_q))
-        if delays is None:
-            delays = [0] * n_q
-        assert len(flattening) == n_q
-        assert len(delays) == n_q
-        assert sorted(flattening) == flattening
-        assert sorted(delays) == delays
-        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
-        self.max_delay = max(delays)
-
-    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
-        """Build a flattened codebooks representation as a dictionary of inner step
-        and the actual codebook indices corresponding to the flattened codebook. For convenience, we
-        also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
-        """
-        flattened_codebooks: dict = {}
-        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
-            if inner_step not in flattened_codebooks:
-                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
-            else:
-                flat_codebook = flattened_codebooks[inner_step]
-                assert flat_codebook.delay == delay, (
-                    "Delay and flattening between codebooks is inconsistent: ",
-                    "two codebooks flattened to the same position should have the same delay."
-                )
-                flat_codebook.codebooks.append(q)
-            flattened_codebooks[inner_step] = flat_codebook
-        return flattened_codebooks
-
-    @property
-    def _num_inner_steps(self):
-        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
-        """
-        return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
-
-    def num_virtual_steps(self, timesteps: int) -> int:
-        return timesteps * self._num_inner_steps + 1
-
-    def get_pattern(self, timesteps: int) -> Pattern:
-        """Builds pattern for delay across codebooks.
-
-        Args:
-            timesteps (int): Total number of timesteps.
-        """
-        # the PatternLayout is built as a tuple of sequence position and list of coordinates
-        # so that it can be reordered properly given the required delay between codebooks of given timesteps
-        indexed_out: list = [(-1, [])]
-        max_timesteps = timesteps + self.max_delay
-        for t in range(max_timesteps):
-            # for each timestep, we unroll the flattened codebooks,
-            # emitting the sequence step with the corresponding delay
-            for step in range(self._num_inner_steps):
-                if step in self._flattened_codebooks:
-                    # we have codebooks at this virtual step to emit
-                    step_codebooks = self._flattened_codebooks[step]
-                    t_for_q = t + step_codebooks.delay
-                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
-                    if t_for_q < max_timesteps and t < max_timesteps:
-                        indexed_out.append((t_for_q, coords))
-                else:
-                    # there is no codebook in this virtual step so we emit an empty list
-                    indexed_out.append((t, []))
-        out = [coords for _, coords in sorted(indexed_out)]
-        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
-
-
-class CoarseFirstPattern(CodebooksPatternProvider):
-    """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
-    potentially with delays.
-
-    ..Warning:: You must always generate the full training duration at test time, for instance,
-        30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
-        location. This is due to the non causality of the remaining codebooks with respect to
-        the first ones.
-
-    Args:
-        n_q (int): Number of codebooks.
-        delays (list of int, optional): Delay for each of the codebooks.
-            If delays not defined, each codebook is delayed by 1 compared to the previous one.
-    """
-    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
-        super().__init__(n_q)
-        if delays is None:
-            delays = [0] * (n_q - 1)
-        self.delays = delays
-        assert len(self.delays) == self.n_q - 1
-        assert sorted(self.delays) == self.delays
-
-    def get_pattern(self, timesteps: int) -> Pattern:
-        out: PatternLayout = [[]]
-        for t in range(timesteps):
-            out.append([LayoutCoord(t, 0)])
-        max_delay = max(self.delays)
-        for t in range(timesteps + max_delay):
-            v = []
-            for q, delay in enumerate(self.delays):
-                t_for_q = t - delay
-                if t_for_q >= 0:
-                    v.append(LayoutCoord(t_for_q, q + 1))
-            out.append(v)
-        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
-
-
-class MusicLMPattern(CodebooksPatternProvider):
-    """Almost MusicLM style pattern. This is equivalent to full flattening
-    but in a different order.
-
-    Args:
-        n_q (int): Number of codebooks.
-        group_by (int): Number of codebooks to group together.
-    """
-    def __init__(self, n_q: int, group_by: int = 2):
-        super().__init__(n_q)
-        self.group_by = group_by
-
-    def get_pattern(self, timesteps: int) -> Pattern:
-        out: PatternLayout = [[]]
-        for offset in range(0, self.n_q, self.group_by):
-            for t in range(timesteps):
-                for q in range(offset, offset + self.group_by):
-                    out.append([LayoutCoord(t, q)])
-        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
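Only the delayed pattern survives the cleanup. A minimal sketch (assumption, not from this commit) of what DelayedPatternProvider builds, to make the retained behaviour concrete:

# Sketch only: with the default delays [0, 1, 2, 3], codebook k is shifted k steps
# to the right, so sequence step s holds coordinates (t=s-1-k, q=k) where valid.
from audiocraft.codebooks_patterns import DelayedPatternProvider

provider = DelayedPatternProvider(n_q=4)
pattern = provider.get_pattern(timesteps=6)
for s, coords in enumerate(pattern.layout[:4]):
    print(s, [(c.t, c.q) for c in coords])  # step 0 is empty, then [(0, 0)], [(1, 0), (0, 1)], ...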
audiocraft/conditioners.py
CHANGED
@@ -410,7 +410,10 @@ class ConditionFuser(StreamingModule):
         # print(f'{self.cond2fuse=}') - self.cond2fuse={'description': 'cross'}
 
         cross_attention_output = cond
-
+        # print(f'{cross_attention_output.shape=} for {input.sum()=}')
+        # cross_attention_output.shape=torch.Size([2, 5, 1536]) for input.sum()=tensor(-0.0650, device='cuda:0')
+        # cross_attention_output.shape=torch.Size([2, 5, 1536]) for input.sum()=tensor(3.7672, device='cuda:0')
+
 
         if self._is_streaming:
             self._streaming_state['offsets'] = offsets + T
audiocraft/encodec.py
CHANGED
@@ -77,42 +77,7 @@ class CompressionModel(ABC, nn.Module):
         """Set the active number of codebooks used by the quantizer."""
         ...
 
-    @staticmethod
-    def get_pretrained(
-            name: str, device: tp.Union[torch.device, str] = 'cpu'
-            ) -> 'CompressionModel':
-        """Instantiate a CompressionModel from a given pretrained model.
-
-        Args:
-            name (Path or str): name of the pretrained model. See after.
-            device (torch.device or str): Device on which the model is loaded.
-
-        Pretrained models:
-            - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
-            - dac_24khz (same)
-            - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
-            - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
-            - your own model on Hugging Face. Export instructions to come...
-        """
-
-        from . import builders, loaders
-        model: CompressionModel
-        if name in ['dac_44khz', 'dac_24khz']:
-            model_type = name.split('_')[1]
-            logger.info("Getting pretrained compression model from DAC %s", model_type)
-            model = DAC(model_type)
-        elif name in ['debug_compression_model']:
-            logger.info("Getting pretrained compression model for debug")
-            model = builders.get_debug_compression_model()
-        elif Path(name).exists():
-            # We assume here if the path exists that it is in fact an AC checkpoint
-            # that was exported using `audiocraft.utils.export` functions.
-            model = loaders.load_compression_model(name, device=device)
-        else:
-            logger.info("Getting pretrained compression model from HF %s", name)
-            hf_model = HFEncodecModel.from_pretrained(name)
-            model = HFEncodecCompressionModel(hf_model).to(device)
-        return model.to(device).eval()
+
 
 
 class EncodecModel(CompressionModel):
@@ -196,20 +161,13 @@ class EncodecModel(CompressionModel):
         return x
 
     def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
-        """Decode the given codes to a reconstructed representation, using the scale to perform
-        audio denormalization if needed.
-
-        Args:
-            codes (torch.Tensor): Int tensor of shape [B, K, T]
-            scale (torch.Tensor, optional): Float tensor containing the scale value.
-
-        Returns:
-            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
-        """
+        # B,K,T -> B,C,T
         emb = self.decode_latent(codes)
+
         out = self.decoder(emb)
+
         out = self.postprocess(out, scale)
-
+
         return out
 
     def decode_latent(self, codes: torch.Tensor):
audiocraft/lm.py
CHANGED
@@ -1,769 +1,27 @@
|
|
1 |
-
# ========================= From conditioners.py
|
2 |
-
import soundfile
|
3 |
-
from collections import defaultdict
|
4 |
-
from copy import deepcopy
|
5 |
from dataclasses import dataclass, field
|
6 |
from itertools import chain
|
7 |
import logging
|
8 |
import math
|
9 |
-
from pathlib import Path
|
10 |
-
import random
|
11 |
import re
|
12 |
import typing as tp
|
13 |
-
import warnings
|
14 |
-
import einops
|
15 |
-
from num2words import num2words
|
16 |
-
import spacy
|
17 |
-
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
|
18 |
import torch
|
19 |
import torch.nn.functional as F
|
20 |
-
from torch.nn.utils.rnn import pad_sequence
|
21 |
from audiocraft.streaming import StreamingModule
|
22 |
-
from audiocraft.transformer import create_sin_embedding
|
23 |
-
from audiocraft.utils.autocast import TorchAutocast
|
24 |
-
from audiocraft.utils.utils import collate, length_to_mask
|
25 |
from audiocraft.transformer import StreamingTransformer, create_norm_fn
|
26 |
from dataclasses import dataclass
|
27 |
from functools import partial
|
28 |
-
import logging
|
29 |
-
import math
|
30 |
-
import typing as tp
|
31 |
-
|
32 |
-
|
33 |
from torch import nn
|
34 |
-
|
35 |
from audiocraft.utils import utils
|
36 |
-
from audiocraft.codebooks_patterns import CodebooksPatternProvider
|
37 |
from audiocraft.activations import get_activation_fn
|
38 |
|
39 |
|
40 |
-
|
41 |
|
42 |
|
43 |
logger = logging.getLogger(__name__)
|
44 |
TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
|
45 |
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
|
46 |
|
47 |
-
|
48 |
-
class WavCondition(tp.NamedTuple):
|
49 |
-
wav: torch.Tensor
|
50 |
-
length: torch.Tensor
|
51 |
-
sample_rate: tp.List[int]
|
52 |
-
path: tp.List[tp.Optional[str]] = []
|
53 |
-
seek_time: tp.List[tp.Optional[float]] = []
|
54 |
-
|
55 |
-
|
56 |
-
class JointEmbedCondition(tp.NamedTuple):
|
57 |
-
wav: torch.Tensor
|
58 |
-
text: tp.List[tp.Optional[str]]
|
59 |
-
length: torch.Tensor
|
60 |
-
sample_rate: tp.List[int]
|
61 |
-
path: tp.List[tp.Optional[str]] = []
|
62 |
-
seek_time: tp.List[tp.Optional[float]] = []
|
63 |
-
|
64 |
-
|
65 |
-
@dataclass
|
66 |
-
class ConditioningAttributes:
|
67 |
-
text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
|
68 |
-
wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
|
69 |
-
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
|
70 |
-
|
71 |
-
def __getitem__(self, item):
|
72 |
-
return getattr(self, item)
|
73 |
-
|
74 |
-
@property
|
75 |
-
def text_attributes(self):
|
76 |
-
return self.text.keys()
|
77 |
-
|
78 |
-
@property
|
79 |
-
def wav_attributes(self):
|
80 |
-
return self.wav.keys()
|
81 |
-
|
82 |
-
@property
|
83 |
-
def joint_embed_attributes(self):
|
84 |
-
return self.joint_embed.keys()
|
85 |
-
|
86 |
-
@property
|
87 |
-
def attributes(self):
|
88 |
-
return {
|
89 |
-
"text": self.text_attributes,
|
90 |
-
"wav": self.wav_attributes,
|
91 |
-
"joint_embed": self.joint_embed_attributes,
|
92 |
-
}
|
93 |
-
|
94 |
-
def to_flat_dict(self):
|
95 |
-
return {
|
96 |
-
**{f"text.{k}": v for k, v in self.text.items()},
|
97 |
-
**{f"wav.{k}": v for k, v in self.wav.items()},
|
98 |
-
**{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
|
99 |
-
}
|
100 |
-
|
101 |
-
@classmethod
|
102 |
-
def from_flat_dict(cls, x):
|
103 |
-
out = cls()
|
104 |
-
for k, v in x.items():
|
105 |
-
kind, att = k.split(".")
|
106 |
-
out[kind][att] = v
|
107 |
-
return out
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
def nullify_condition(condition: ConditionType, dim: int = 1):
|
114 |
-
"""Transform an input condition to a null condition.
|
115 |
-
The way it is done by converting it to a single zero vector similarly
|
116 |
-
to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
|
117 |
-
|
118 |
-
Args:
|
119 |
-
condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
|
120 |
-
dim (int): The dimension that will be truncated (should be the time dimension)
|
121 |
-
WARNING!: dim should not be the batch dimension!
|
122 |
-
Returns:
|
123 |
-
ConditionType: A tuple of null condition and mask
|
124 |
-
"""
|
125 |
-
assert dim != 0, "dim cannot be the batch dimension!"
|
126 |
-
assert isinstance(condition, tuple) and \
|
127 |
-
isinstance(condition[0], torch.Tensor) and \
|
128 |
-
isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
|
129 |
-
cond, mask = condition
|
130 |
-
B = cond.shape[0]
|
131 |
-
last_dim = cond.dim() - 1
|
132 |
-
out = cond.transpose(dim, last_dim)
|
133 |
-
out = 0. * out[..., :1]
|
134 |
-
out = out.transpose(dim, last_dim)
|
135 |
-
mask = torch.zeros((B, 1), device=out.device).int()
|
136 |
-
assert cond.dim() == out.dim()
|
137 |
-
return out, mask
|
138 |
-
|
139 |
-
|
140 |
-
def nullify_wav(cond: WavCondition) -> WavCondition:
|
141 |
-
"""Transform a WavCondition to a nullified WavCondition.
|
142 |
-
It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
|
143 |
-
|
144 |
-
Args:
|
145 |
-
cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
|
146 |
-
Returns:
|
147 |
-
WavCondition: Nullified wav condition.
|
148 |
-
"""
|
149 |
-
null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
|
150 |
-
return WavCondition(
|
151 |
-
wav=null_wav,
|
152 |
-
length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
|
153 |
-
sample_rate=cond.sample_rate,
|
154 |
-
path=[None] * cond.wav.shape[0],
|
155 |
-
seek_time=[None] * cond.wav.shape[0],
|
156 |
-
)
|
157 |
-
|
158 |
-
|
159 |
-
def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
|
160 |
-
"""Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
|
161 |
-
and replacing metadata by dummy attributes.
|
162 |
-
|
163 |
-
Args:
|
164 |
-
cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
|
165 |
-
"""
|
166 |
-
null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
|
167 |
-
return JointEmbedCondition(
|
168 |
-
wav=null_wav, text=[None] * len(embed.text),
|
169 |
-
length=torch.LongTensor([0]).to(embed.wav.device),
|
170 |
-
sample_rate=embed.sample_rate,
|
171 |
-
path=[None] * embed.wav.shape[0],
|
172 |
-
seek_time=[0] * embed.wav.shape[0],
|
173 |
-
)
|
174 |
-
|
175 |
-
|
176 |
-
class Tokenizer:
|
177 |
-
"""Base tokenizer implementation
|
178 |
-
(in case we want to introduce more advances tokenizers in the future).
|
179 |
-
"""
|
180 |
-
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
181 |
-
raise NotImplementedError()
|
182 |
-
|
183 |
-
|
184 |
-
class WhiteSpaceTokenizer(Tokenizer):
|
185 |
-
"""This tokenizer should be used for natural language descriptions.
|
186 |
-
For example:
|
187 |
-
["he didn't, know he's going home.", 'shorter sentence'] =>
|
188 |
-
[[78, 62, 31, 4, 78, 25, 19, 34],
|
189 |
-
[59, 77, 0, 0, 0, 0, 0, 0]]
|
190 |
-
"""
|
191 |
-
PUNCTUATION = "?:!.,;"
|
192 |
-
|
193 |
-
def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
|
194 |
-
lemma: bool = True, stopwords: bool = True) -> None:
|
195 |
-
self.n_bins = n_bins
|
196 |
-
self.pad_idx = pad_idx
|
197 |
-
self.lemma = lemma
|
198 |
-
self.stopwords = stopwords
|
199 |
-
try:
|
200 |
-
self.nlp = spacy.load(language)
|
201 |
-
except IOError:
|
202 |
-
spacy.cli.download(language) # type: ignore
|
203 |
-
self.nlp = spacy.load(language)
|
204 |
-
|
205 |
-
@tp.no_type_check
|
206 |
-
def __call__(self, texts: tp.List[tp.Optional[str]],
|
207 |
-
return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
208 |
-
"""Take a list of strings and convert them to a tensor of indices.
|
209 |
-
|
210 |
-
Args:
|
211 |
-
texts (list[str]): List of strings.
|
212 |
-
return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
|
213 |
-
Returns:
|
214 |
-
tuple[torch.Tensor, torch.Tensor]:
|
215 |
-
- Indices of words in the LUT.
|
216 |
-
- And a mask indicating where the padding tokens are
|
217 |
-
"""
|
218 |
-
output, lengths = [], []
|
219 |
-
texts = deepcopy(texts)
|
220 |
-
for i, text in enumerate(texts):
|
221 |
-
# if current sample doesn't have a certain attribute, replace with pad token
|
222 |
-
if text is None:
|
223 |
-
output.append(torch.Tensor([self.pad_idx]))
|
224 |
-
lengths.append(0)
|
225 |
-
continue
|
226 |
-
|
227 |
-
# convert numbers to words
|
228 |
-
text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore
|
229 |
-
# normalize text
|
230 |
-
text = self.nlp(text) # type: ignore
|
231 |
-
# remove stopwords
|
232 |
-
if self.stopwords:
|
233 |
-
text = [w for w in text if not w.is_stop] # type: ignore
|
234 |
-
# remove punctuation
|
235 |
-
text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore
|
236 |
-
# lemmatize if needed
|
237 |
-
text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
|
238 |
-
|
239 |
-
texts[i] = " ".join(text)
|
240 |
-
lengths.append(len(text))
|
241 |
-
# convert to tensor
|
242 |
-
tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
|
243 |
-
output.append(tokens)
|
244 |
-
|
245 |
-
mask = length_to_mask(torch.IntTensor(lengths)).int()
|
246 |
-
padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
|
247 |
-
if return_text:
|
248 |
-
return padded_output, mask, texts # type: ignore
|
249 |
-
return padded_output, mask
|
250 |
-
|
251 |
-
|
252 |
-
class NoopTokenizer(Tokenizer):
|
253 |
-
"""This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
|
254 |
-
The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
|
255 |
-
strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
|
256 |
-
split it to ["Jeff", "Buckley"] and return an index per word.
|
257 |
-
|
258 |
-
For example:
|
259 |
-
["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
|
260 |
-
["Metal", "Rock", "Classical"] => [0, 223, 51]
|
261 |
-
"""
|
262 |
-
def __init__(self, n_bins: int, pad_idx: int = 0):
|
263 |
-
self.n_bins = n_bins
|
264 |
-
self.pad_idx = pad_idx
|
265 |
-
|
266 |
-
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
267 |
-
output, lengths = [], []
|
268 |
-
for text in texts:
|
269 |
-
# if current sample doesn't have a certain attribute, replace with pad token
|
270 |
-
if text is None:
|
271 |
-
output.append(self.pad_idx)
|
272 |
-
lengths.append(0)
|
273 |
-
else:
|
274 |
-
output.append(hash_trick(text, self.n_bins))
|
275 |
-
lengths.append(1)
|
276 |
-
|
277 |
-
tokens = torch.LongTensor(output).unsqueeze(1)
|
278 |
-
mask = length_to_mask(torch.IntTensor(lengths)).int()
|
279 |
-
return tokens, mask
|
280 |
-
|
281 |
-
|
282 |
-
class BaseConditioner(nn.Module):
|
283 |
-
"""Base model for all conditioner modules.
|
284 |
-
We allow the output dim to be different than the hidden dim for two reasons:
|
285 |
-
1) keep our LUTs small when the vocab is large;
|
286 |
-
2) make all condition dims consistent.
|
287 |
-
|
288 |
-
Args:
|
289 |
-
dim (int): Hidden dim of the model.
|
290 |
-
output_dim (int): Output dim of the conditioner.
|
291 |
-
"""
|
292 |
-
def __init__(self, dim: int, output_dim: int):
|
293 |
-
super().__init__()
|
294 |
-
self.dim = dim
|
295 |
-
self.output_dim = output_dim
|
296 |
-
self.output_proj = nn.Linear(dim, output_dim)
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
def forward(self, inputs: tp.Any) -> ConditionType:
|
301 |
-
"""Gets input that should be used as conditioning (e.g, genre, description or a waveform).
|
302 |
-
Outputs a ConditionType, after the input data was embedded as a dense vector.
|
303 |
-
|
304 |
-
Returns:
|
305 |
-
ConditionType:
|
306 |
-
- A tensor of size [B, T, D] where B is the batch size, T is the length of the
|
307 |
-
output embedding and D is the dimension of the embedding.
|
308 |
-
- And a mask indicating where the padding tokens.
|
309 |
-
"""
|
310 |
-
raise NotImplementedError()
|
311 |
-
|
312 |
-
|
313 |
-
class TextConditioner(BaseConditioner):
|
314 |
-
...
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
class T5Conditioner(TextConditioner):
|
321 |
-
"""T5-based TextConditioner.
|
322 |
-
|
323 |
-
Args:
|
324 |
-
name (str): Name of the T5 model.
|
325 |
-
output_dim (int): Output dim of the conditioner.
|
326 |
-
finetune (bool): Whether to fine-tune T5 at train time.
|
327 |
-
device (str): Device for T5 Conditioner.
|
328 |
-
autocast_dtype (tp.Optional[str], optional): Autocast dtype.
|
329 |
-
word_dropout (float, optional): Word dropout probability.
|
330 |
-
normalize_text (bool, optional): Whether to apply text normalization.
|
331 |
-
"""
|
332 |
-
MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
|
333 |
-
"google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
|
334 |
-
"google/flan-t5-xl", "google/flan-t5-xxl"]
|
335 |
-
MODELS_DIMS = {
|
336 |
-
"t5-small": 512,
|
337 |
-
"t5-base": 768,
|
338 |
-
"t5-large": 1024,
|
339 |
-
"t5-3b": 1024,
|
340 |
-
"t5-11b": 1024,
|
341 |
-
"google/flan-t5-small": 512,
|
342 |
-
"google/flan-t5-base": 768,
|
343 |
-
"google/flan-t5-large": 1024,
|
344 |
-
"google/flan-t5-3b": 1024,
|
345 |
-
"google/flan-t5-11b": 1024,
|
346 |
-
}
|
347 |
-
|
348 |
-
def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
|
349 |
-
autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
|
350 |
-
normalize_text: bool = False):
|
351 |
-
assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
|
352 |
-
super().__init__(self.MODELS_DIMS[name], output_dim)
|
353 |
-
self.device = device
|
354 |
-
self.name = name
|
355 |
-
self.finetune = finetune
|
356 |
-
self.word_dropout = word_dropout
|
357 |
-
if autocast_dtype is None or self.device == 'cpu':
|
358 |
-
self.autocast = TorchAutocast(enabled=False)
|
359 |
-
if self.device != 'cpu':
|
360 |
-
logger.warning("T5 has no autocast, this might lead to NaN")
|
361 |
-
else:
|
362 |
-
dtype = getattr(torch, autocast_dtype)
|
363 |
-
assert isinstance(dtype, torch.dtype)
|
364 |
-
logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
|
365 |
-
self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
|
366 |
-
# Let's disable logging temporarily because T5 will vomit some errors otherwise.
|
367 |
-
# thanks https://gist.github.com/simon-weber/7853144
|
368 |
-
previous_level = logging.root.manager.disable
|
369 |
-
logging.disable(logging.ERROR)
|
370 |
-
with warnings.catch_warnings():
|
371 |
-
warnings.simplefilter("ignore")
|
372 |
-
try:
|
373 |
-
self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
|
374 |
-
t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
|
375 |
-
finally:
|
376 |
-
logging.disable(previous_level)
|
377 |
-
if finetune:
|
378 |
-
self.t5 = t5
|
379 |
-
else:
|
380 |
-
# this makes sure that the t5 models is not part
|
381 |
-
# of the saved checkpoint
|
382 |
-
self.__dict__['t5'] = t5.to(device)
|
383 |
-
|
384 |
-
self.normalize_text = normalize_text
|
385 |
-
if normalize_text:
|
386 |
-
self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
|
387 |
-
|
388 |
-
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
|
389 |
-
# if current sample doesn't have a certain attribute, replace with empty string
|
390 |
-
entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
|
391 |
-
if self.normalize_text:
|
392 |
-
_, _, entries = self.text_normalizer(entries, return_text=True)
|
393 |
-
if self.word_dropout > 0. and self.training:
|
394 |
-
new_entries = []
|
395 |
-
for entry in entries:
|
396 |
-
words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
|
397 |
-
new_entries.append(" ".join(words))
|
398 |
-
entries = new_entries
|
399 |
-
|
400 |
-
empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
|
401 |
-
|
402 |
-
inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
|
403 |
-
mask = inputs['attention_mask']
|
404 |
-
mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
|
405 |
-
return inputs
|
406 |
-
|
407 |
-
def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
|
408 |
-
mask = inputs['attention_mask']
|
409 |
-
with torch.set_grad_enabled(self.finetune), self.autocast:
|
410 |
-
embeds = self.t5(**inputs).last_hidden_state
|
411 |
-
embeds = self.output_proj(embeds.to(self.output_proj.weight))
|
412 |
-
embeds = (embeds * mask.unsqueeze(-1))
|
413 |
-
return embeds, mask
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
class JointEmbeddingConditioner(BaseConditioner):
|
423 |
-
"""Joint embedding conditioning supporting both audio or text conditioning.
|
424 |
-
|
425 |
-
Args:
|
426 |
-
dim (int): Dimension.
|
427 |
-
output_dim (int): Output dimension.
|
428 |
-
device (str): Device.
|
429 |
-
attribute (str): Attribute used by the conditioner.
|
430 |
-
autocast_dtype (str): Autocast for the conditioner.
|
431 |
-
quantize (bool): Whether to quantize the CLAP embedding.
|
432 |
-
n_q (int): Number of residual quantizers (used if quantize is true).
|
433 |
-
bins (int): Quantizers' codebooks size (used if quantize is true).
|
434 |
-
kwargs: Additional parameters for residual vector quantizer.
|
435 |
-
"""
|
436 |
-
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
|
437 |
-
autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
|
438 |
-
n_q: int = 12, bins: int = 1024, **kwargs):
|
439 |
-
super().__init__(dim=dim, output_dim=output_dim)
|
440 |
-
self.device = device
|
441 |
-
self.attribute = attribute
|
442 |
-
if autocast_dtype is None or device == 'cpu':
|
443 |
-
self.autocast = TorchAutocast(enabled=False)
|
444 |
-
logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
|
445 |
-
else:
|
446 |
-
dtype = getattr(torch, autocast_dtype)
|
447 |
-
assert isinstance(dtype, torch.dtype)
|
448 |
-
logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
|
449 |
-
self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
|
450 |
-
# residual vector quantizer to discretize the conditioned embedding
|
451 |
-
self.quantizer=None
|
452 |
-
if quantize:
|
453 |
-
print('\n\n\n\nWANTS TO QUANTIZE on Inference\n\n\n\n')
|
454 |
-
# self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
|
455 |
-
|
456 |
-
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
457 |
-
"""Get joint embedding in latent space from the inputs.
|
458 |
-
|
459 |
-
Returns:
|
460 |
-
tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
|
461 |
-
and corresponding empty indexes.
|
462 |
-
"""
|
463 |
-
raise NotImplementedError()
|
464 |
-
|
465 |
-
def forward(self, x: JointEmbedCondition) -> ConditionType:
|
466 |
-
with self.autocast:
|
467 |
-
embed, empty_idx = self._get_embed(x)
|
468 |
-
if self.quantizer is not None:
|
469 |
-
embed = embed.view(-1, self.dim, 1)
|
470 |
-
q_res = self.quantizer(embed, frame_rate=1)
|
471 |
-
out_embed = q_res.x.view(-1, self.dim)
|
472 |
-
else:
|
473 |
-
out_embed = embed
|
474 |
-
out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
|
475 |
-
mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
|
476 |
-
mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
|
477 |
-
out_embed = (out_embed * mask.unsqueeze(-1))
|
478 |
-
return out_embed, mask
|
479 |
-
|
480 |
-
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
481 |
-
return x
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
class ConditioningProvider(nn.Module):
|
494 |
-
"""Prepare and provide conditions given all the supported conditioners.
|
495 |
-
|
496 |
-
Args:
|
497 |
-
conditioners (dict): Dictionary of conditioners.
|
498 |
-
device (torch.device or str, optional): Device for conditioners and output condition types.
|
499 |
-
"""
|
500 |
-
def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
|
501 |
-
super().__init__()
|
502 |
-
self.device = device
|
503 |
-
self.conditioners = nn.ModuleDict(conditioners)
|
504 |
-
|
505 |
-
@property
|
506 |
-
def joint_embed_conditions(self):
|
507 |
-
return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
|
508 |
-
|
509 |
-
@property
|
510 |
-
def has_joint_embed_conditions(self):
|
511 |
-
return len(self.joint_embed_conditions) > 0
|
512 |
-
|
513 |
-
@property
|
514 |
-
def text_conditions(self):
|
515 |
-
return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
|
516 |
-
|
517 |
-
@property
|
518 |
-
def wav_conditions(self):
|
519 |
-
return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
|
520 |
-
|
521 |
-
@property
|
522 |
-
def has_wav_condition(self):
|
523 |
-
return len(self.wav_conditions) > 0
|
524 |
-
|
525 |
-
def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
|
526 |
-
"""Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
|
527 |
-
The output is for example:
|
528 |
-
{
|
529 |
-
"genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
|
530 |
-
"description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
|
531 |
-
...
|
532 |
-
}
|
533 |
-
|
534 |
-
Args:
|
535 |
-
tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
|
536 |
-
"""
|
537 |
-
output = {}
|
538 |
-
for attribute, inputs in tokenized.items():
|
539 |
-
condition, mask = self.conditioners[attribute](inputs)
|
540 |
-
output[attribute] = (condition, mask)
|
541 |
-
return output
|
542 |
-
|
543 |
-
def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
|
544 |
-
"""Given a list of ConditioningAttributes objects, compile a dictionary where the keys
|
545 |
-
are the attributes and the values are the aggregated input per attribute.
|
546 |
-
For example:
|
547 |
-
Input:
|
548 |
-
[
|
549 |
-
ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
|
550 |
-
ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
|
551 |
-
]
|
552 |
-
Output:
|
553 |
-
{
|
554 |
-
"genre": ["Rock", "Hip-hop"],
|
555 |
-
"description": ["A rock song with a guitar solo", "A hip-hop verse"]
|
556 |
-
}
|
557 |
-
|
558 |
-
Args:
|
559 |
-
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
560 |
-
Returns:
|
561 |
-
dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
|
562 |
-
"""
|
563 |
-
out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
|
564 |
-
texts = [x.text for x in samples]
|
565 |
-
for text in texts:
|
566 |
-
for condition in self.text_conditions:
|
567 |
-
out[condition].append(text[condition])
|
568 |
-
return out
|
569 |
-
|
570 |
-
def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
|
571 |
-
"""Generate a dict where the keys are attributes by which we fetch similar wavs,
|
572 |
-
and the values are Tensors of wavs according to said attributes.
|
573 |
-
|
574 |
-
*Note*: by the time the samples reach this function, each sample should have some waveform
|
575 |
-
inside the "wav" attribute. It should be either:
|
576 |
-
1. A real waveform
|
577 |
-
2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
|
578 |
-
3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
|
579 |
-
|
580 |
-
Args:
|
581 |
-
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
582 |
-
Returns:
|
583 |
-
dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
|
584 |
-
"""
|
585 |
-
wavs = defaultdict(list)
|
586 |
-
lengths = defaultdict(list)
|
587 |
-
sample_rates = defaultdict(list)
|
588 |
-
paths = defaultdict(list)
|
589 |
-
seek_times = defaultdict(list)
|
590 |
-
out: tp.Dict[str, WavCondition] = {}
|
591 |
-
|
592 |
-
for sample in samples:
|
593 |
-
for attribute in self.wav_conditions:
|
594 |
-
wav, length, sample_rate, path, seek_time = sample.wav[attribute]
|
595 |
-
assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
|
596 |
-
assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
|
597 |
-
# mono-channel conditioning
|
598 |
-
wav = wav.mean(1, keepdim=True) # [1, 1, T]
|
599 |
-
wavs[attribute].append(wav.flatten()) # [T]
|
600 |
-
lengths[attribute].append(length)
|
601 |
-
sample_rates[attribute].extend(sample_rate)
|
602 |
-
paths[attribute].extend(path)
|
603 |
-
seek_times[attribute].extend(seek_time)
|
604 |
-
|
605 |
-
# stack all wavs to a single tensor
|
606 |
-
for attribute in self.wav_conditions:
|
607 |
-
stacked_wav, _ = collate(wavs[attribute], dim=0)
|
608 |
-
out[attribute] = WavCondition(
|
609 |
-
stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
|
610 |
-
paths[attribute], seek_times[attribute])
|
611 |
-
|
612 |
-
return out
|
613 |
-
|
614 |
-
def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
|
615 |
-
"""Generate a dict where the keys are attributes by which we compute joint embeddings,
|
616 |
-
and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
|
617 |
-
|
618 |
-
Args:
|
619 |
-
samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
|
620 |
-
Returns:
|
621 |
-
A dictionary mapping an attribute name to joint embeddings.
|
622 |
-
"""
|
623 |
-
texts = defaultdict(list)
|
624 |
-
wavs = defaultdict(list)
|
625 |
-
lengths = defaultdict(list)
|
626 |
-
sample_rates = defaultdict(list)
|
627 |
-
paths = defaultdict(list)
|
628 |
-
seek_times = defaultdict(list)
|
629 |
-
channels: int = 0
|
630 |
-
|
631 |
-
out = {}
|
632 |
-
for sample in samples:
|
633 |
-
for attribute in self.joint_embed_conditions:
|
634 |
-
wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
|
635 |
-
assert wav.dim() == 3
|
636 |
-
if channels == 0:
|
637 |
-
channels = wav.size(1)
|
638 |
-
else:
|
639 |
-
assert channels == wav.size(1), "not all audio has same number of channels in batch"
|
640 |
-
assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
|
641 |
-
wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T]
|
642 |
-
wavs[attribute].append(wav)
|
643 |
-
texts[attribute].extend(text)
|
644 |
-
lengths[attribute].append(length)
|
645 |
-
sample_rates[attribute].extend(sample_rate)
|
646 |
-
paths[attribute].extend(path)
|
647 |
-
seek_times[attribute].extend(seek_time)
|
648 |
-
|
649 |
-
for attribute in self.joint_embed_conditions:
|
650 |
-
stacked_texts = texts[attribute]
|
651 |
-
stacked_paths = paths[attribute]
|
652 |
-
stacked_seek_times = seek_times[attribute]
|
653 |
-
stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
|
654 |
-
stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
|
655 |
-
stacked_sample_rates = sample_rates[attribute]
|
656 |
-
stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
|
657 |
-
assert stacked_lengths.size(0) == stacked_wavs.size(0)
|
658 |
-
assert len(stacked_sample_rates) == stacked_wavs.size(0)
|
659 |
-
assert len(stacked_texts) == stacked_wavs.size(0)
|
660 |
-
out[attribute] = JointEmbedCondition(
|
661 |
-
text=stacked_texts, wav=stacked_wavs,
|
662 |
-
length=stacked_lengths, sample_rate=stacked_sample_rates,
|
663 |
-
path=stacked_paths, seek_time=stacked_seek_times)
|
664 |
-
|
665 |
-
return out
|
666 |
-
|
667 |
-
|
668 |
-
class ConditionFuser(StreamingModule):
|
669 |
-
"""Condition fuser handles the logic to combine the different conditions
|
670 |
-
to the actual model input.
|
671 |
-
|
672 |
-
Args:
|
673 |
-
fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
|
674 |
-
each condition. For example:
|
675 |
-
{
|
676 |
-
"prepend": ["description"],
|
677 |
-
"sum": ["genre", "bpm"],
|
678 |
-
"cross": ["description"],
|
679 |
-
}
|
680 |
-
cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
|
681 |
-
cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
|
682 |
-
"""
|
683 |
-
FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
|
684 |
-
|
685 |
-
def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
|
686 |
-
cross_attention_pos_emb_scale: float = 1.0):
|
687 |
-
super().__init__()
|
688 |
-
assert all(
|
689 |
-
[k in self.FUSING_METHODS for k in fuse2cond.keys()]
|
690 |
-
), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
|
691 |
-
self.cross_attention_pos_emb = cross_attention_pos_emb
|
692 |
-
self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
|
693 |
-
self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
|
694 |
-
self.cond2fuse: tp.Dict[str, str] = {}
|
695 |
-
for fuse_method, conditions in fuse2cond.items():
|
696 |
-
for condition in conditions:
|
697 |
-
self.cond2fuse[condition] = fuse_method
|
698 |
-
|
699 |
-
def forward(
|
700 |
-
self,
|
701 |
-
input: torch.Tensor,
|
702 |
-
conditions: tp.Dict[str, ConditionType]
|
703 |
-
) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
704 |
-
"""Fuse the conditions to the provided model input.
|
705 |
-
|
706 |
-
Args:
|
707 |
-
input (torch.Tensor): Transformer input.
|
708 |
-
conditions (dict[str, ConditionType]): Dict of conditions.
|
709 |
-
Returns:
|
710 |
-
tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
|
711 |
-
after the conditions have been fused. The second output tensor is the tensor
|
712 |
-
used for cross-attention or None if no cross attention inputs exist.
|
713 |
-
"""
|
714 |
-
B, T, _ = input.shape
|
715 |
-
|
716 |
-
if 'offsets' in self._streaming_state:
|
717 |
-
first_step = False
|
718 |
-
offsets = self._streaming_state['offsets']
|
719 |
-
else:
|
720 |
-
first_step = True
|
721 |
-
offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
|
722 |
-
|
723 |
-
assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
|
724 |
-
f"given conditions contain unknown attributes for fuser, " \
|
725 |
-
f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
|
726 |
-
cross_attention_output = None
|
727 |
-
for cond_type, (cond, cond_mask) in conditions.items():
|
728 |
-
op = self.cond2fuse[cond_type]
|
729 |
-
if op == 'sum':
|
730 |
-
input += cond
|
731 |
-
elif op == 'input_interpolate':
|
732 |
-
cond = einops.rearrange(cond, "b t d -> b d t")
|
733 |
-
cond = F.interpolate(cond, size=input.shape[1])
|
734 |
-
input += einops.rearrange(cond, "b d t -> b t d")
|
735 |
-
elif op == 'prepend':
|
736 |
-
if first_step:
|
737 |
-
input = torch.cat([cond, input], dim=1)
|
738 |
-
elif op == 'cross':
|
739 |
-
if cross_attention_output is not None:
|
740 |
-
cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
|
741 |
-
else:
|
742 |
-
cross_attention_output = cond
|
743 |
-
else:
|
744 |
-
raise ValueError(f"unknown op ({op})")
|
745 |
-
|
746 |
-
if self.cross_attention_pos_emb and cross_attention_output is not None:
|
747 |
-
print('SIN EMBED')
|
748 |
-
positions = torch.arange(
|
749 |
-
cross_attention_output.shape[1],
|
750 |
-
device=cross_attention_output.device
|
751 |
-
).view(1, -1, 1)
|
752 |
-
pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
|
753 |
-
cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
|
754 |
-
|
755 |
-
if self._is_streaming:
|
756 |
-
self._streaming_state['offsets'] = offsets + T
|
757 |
-
|
758 |
-
return input, cross_attention_output
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
# ============================================== From LM.py
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
logger = logging.getLogger(__name__)
|
767 |
ConditionTensors = tp.Dict[str, ConditionType]
|
768 |
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
|
769 |
|
@@ -876,8 +134,11 @@ class LMModel(StreamingModule):
|
|
876 |
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
|
877 |
**kwargs: Additional parameters for the transformer encoder.
|
878 |
"""
|
879 |
-
def __init__(self,
|
880 |
-
|
|
|
|
|
|
|
881 |
hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
|
882 |
emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
|
883 |
weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
|
@@ -952,27 +213,11 @@ class LMModel(StreamingModule):
|
|
952 |
def num_codebooks(self) -> int:
|
953 |
return self.n_q
|
954 |
|
955 |
-
def forward(self,
|
956 |
-
|
957 |
-
|
958 |
-
|
959 |
-
|
960 |
-
Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
|
961 |
-
S the sequence steps, return the logits with shape [B, card, K, S].
|
962 |
-
|
963 |
-
Args:
|
964 |
-
indices (torch.Tensor): Indices of the codes to model.
|
965 |
-
conditions (list of ConditioningAttributes): Conditions to use when modeling
|
966 |
-
the given codes. Note that when evaluating multiple time with the same conditioning
|
967 |
-
you should pre-compute those and pass them as `condition_tensors`.
|
968 |
-
condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
|
969 |
-
tensors, see `conditions`.
|
970 |
-
stage (int): The codebook level that is being predicted. Relevant for MAGNeT
|
971 |
-
in which prediction is done in a codebook-by-codebook manner.
|
972 |
-
Takes values in range(n_q), and ignored by default.
|
973 |
-
Returns:
|
974 |
-
torch.Tensor: Logits.
|
975 |
-
"""
|
976 |
B, K, S = sequence.shape
|
977 |
assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
|
978 |
input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
|
@@ -983,8 +228,8 @@ class LMModel(StreamingModule):
|
|
983 |
condition_tensors = self.condition_provider(tokenized)
|
984 |
else:
|
985 |
assert not conditions, "Shouldn't pass both conditions and condition_tensors."
|
986 |
-
|
987 |
-
input_, cross_attention_input = self.fuser(input_, condition_tensors)
|
988 |
|
989 |
out = self.transformer(input_, cross_attention_src=cross_attention_input,
|
990 |
src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None))
|
@@ -999,60 +244,6 @@ class LMModel(StreamingModule):

        return logits  # [B, K, S, card]

-    def compute_predictions(
-            self, codes: torch.Tensor,
-            conditions: tp.List[ConditioningAttributes],
-            condition_tensors: tp.Optional[ConditionTensors] = None,
-            stage: int = -1,
-            keep_only_valid_steps: bool = True) -> LMOutput:
-        """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
-        forward using the specified codes interleaving pattern.
-
-        Args:
-            codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
-                K the number of codebooks and T the number of timesteps.
-            conditions (list of ConditioningAttributes): conditionings to use when modeling
-                the given codes. Note that when evaluating multiple time with the same conditioning
-                you should pre-compute those and pass them as `condition_tensors`.
-            condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
-                tensors, see `conditions`.
-            stage (int): The codebook level that is being predicted. Relevant for MAGNeT
-                in which prediction is done in a codebook-by-codebook manner.
-                Takes values in range(n_q), and ignored by default.
-            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
-                Steps that are beyond valid steps will be replaced by the special_token in that case.
-        Returns:
-            LMOutput: Language model outputs
-                logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
-                    i.e. the first item corresponds to logits to predict the first code, meaning that
-                    no additional shifting of codes and logits is required.
-                mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
-                    Given the specified interleaving strategies, parts of the logits and codes should
-                    not be considered as valid predictions because of invalid context.
-        """
-        B, K, T = codes.shape
-        codes = codes.contiguous()
-        # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
-        # what is the T is it 2048 ?
-        # and then what is pattern -> another function?
-        pattern = self.pattern_provider.get_pattern(T)
-        sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
-            codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps,
-        )
-
-        # apply model on pattern sequence
-        model = self if self._fsdp is None else self._fsdp
-        logits = model(sequence_codes, conditions, condition_tensors, stage=stage)  # [B, K, S, card]
-        # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
-        # and provide the corresponding mask over invalid positions of tokens
-        logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
-        # note: we use nans as special token to make it obvious if we feed unexpected logits
-        logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
-            logits, float('nan'), keep_only_valid_steps=keep_only_valid_steps
-        )
-        logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
-        logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
-        return LMOutput(logits, logits_mask)

    def _sample_next_token(self,
                           sequence,
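For context on what the removed compute_predictions relied on: the pattern provider maps codes [B, K, T] to an interleaved sequence [B, K, S] padded with a special token, the model is run on that sequence, and the outputs are mapped back to the original timesteps (the removed method used revert_pattern_logits for that last step). Below is a minimal round-trip sketch using only calls that the surviving generate() method also uses; the toy sizes and the special token value are assumptions, not values from this repo.

import torch
from audiocraft.codebooks_patterns import DelayedPatternProvider

B, K, T, card = 2, 4, 16, 2048           # assumed toy sizes
special_token_id = card                  # assumed: one id past the codebook range

pattern = DelayedPatternProvider(n_q=K).get_pattern(T)
codes = torch.randint(0, card, (B, K, T))

# codes [B, K, T] -> interleaved sequence [B, K, S]; undefined slots get the special token
seq, _, mask = pattern.build_pattern_sequence(codes, special_token_id)

# ... the LM would be applied on `seq` here, producing [B, K, S, card] logits ...

# interleaved sequence [B, K, S] -> codes realigned with the original timesteps
out, _, out_mask = pattern.revert_pattern_sequence(seq, special_token=special_token_id)
print(seq.shape, out.shape, out_mask.shape)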
@@ -1127,11 +318,12 @@ class LMModel(StreamingModule):

        return next_token

    @torch.no_grad()
    def generate(self,
-                 prompt
-                 conditions
-                 num_samples
                 max_gen_len: int = 256,
                 use_sampling: bool = True,
                 temp: float = 1.0,
@@ -1143,25 +335,12 @@ class LMModel(StreamingModule):
                 check: bool = False,
                 callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                 **kwargs) -> torch.Tensor:
-        """
-        be performed in a greedy fashion or using sampling with top K and top P strategies.

        Args:
-
-            conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
-            num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
-            max_gen_len (int): Maximum generation length.
-            use_sampling (bool): Whether to use a sampling strategy or not.
-            temp (float): Sampling temperature.
-            top_k (int): K for "top-k" sampling.
-            top_p (float): P for "top-p" sampling.
-            cfg_coeff (float, optional): Classifier-free guidance coefficient.
-            two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
-            remove_prompts (bool): Whether to remove prompts from generation or not.
-            check (bool): Whether to apply further checks on generated sequence.
-            callback (Callback, optional): Callback function to report generation progress.
        Returns:
-            torch.Tensor:
        """
        assert not self.training, "generation shouldn't be used in training mode."
        first_param = next(iter(self.parameters()))
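The removed docstring is the only place that documented temp, top_k and top_p. As a reference, here is a minimal, generic sketch of what those knobs usually control when sampling the next token from a row of logits; this illustrates the idea only and is not the library's own sampling helpers.

import torch

def sample_next_token(logits: torch.Tensor, temp: float = 1.0,
                      top_k: int = 0, top_p: float = 0.0) -> torch.Tensor:
    """Sample one token id per batch row from [B, card] logits."""
    logits = logits / max(temp, 1e-5)                      # temperature scaling
    probs = torch.softmax(logits, dim=-1)
    if top_k > 0:                                          # keep only the k most likely tokens
        topk_probs, topk_idx = probs.topk(top_k, dim=-1)
        probs = torch.zeros_like(probs).scatter_(-1, topk_idx, topk_probs)
    if top_p > 0.0:                                        # keep the smallest set with cumulative mass >= top_p
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        cum = sorted_probs.cumsum(dim=-1)
        sorted_probs[cum - sorted_probs > top_p] = 0.0
        probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    probs = probs / probs.sum(dim=-1, keepdim=True)        # renormalize
    return torch.multinomial(probs, num_samples=1)         # [B, 1]

# top_k=250 mirrors the "top_250 logits" mentioned in the new docstring below
next_token = sample_next_token(torch.randn(2, 2048), temp=1.0, top_k=250)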
@@ -1190,20 +369,13 @@ class LMModel(StreamingModule):
        # the padding structure is exactly the same between train and test.
        # With a batch size of 1, this can be slower though.
        cfg_conditions: CFGConditions
-        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
-
-
-
-
-
-                )
-            else:
-                conditions = conditions + null_conditions
-                tokenized = self.condition_provider.tokenize(conditions)
-                cfg_conditions = self.condition_provider(tokenized)
-        else:
-            cfg_conditions = {}

        if prompt is None:
            assert num_samples > 0
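The removed branch is where classifier-free guidance was assembled (conditional and null conditions batched together, optionally in two steps, controlled by cfg_coeff and two_step_cfg). For reference, the usual way the two logit sets are combined is sketched below; the tensors and names are illustrative, not taken from this repo.

import torch

# cond_logits / uncond_logits: logits for the current step with and without conditioning
cond_logits = torch.randn(2, 4, 2048)
uncond_logits = torch.randn(2, 4, 2048)
cfg_coeff = 3.0

# classifier-free guidance: push the conditional prediction away from the unconditional one
logits = uncond_logits + cfg_coeff * (cond_logits - uncond_logits)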
@@ -1222,18 +394,26 @@ class LMModel(StreamingModule):

        gen_codes[..., :start_offset] = prompt
        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
-        gen_sequence,
-
-        # it is the first sequence step that contains the `start_offset` timestep
        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        with self.streaming():
            unconditional_state = self.get_streaming_state()
            prev_offset = 0
            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
            for offset in range(start_offset_sequence, gen_sequence_len):
                # get current sequence (note that the streaming API is providing the caching over previous offsets)
                curr_sequence = gen_sequence[..., prev_offset:offset]
                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
                if check:
@@ -1268,11 +448,13 @@ class LMModel(StreamingModule):
                    callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
            unconditional_state.clear()

-        out_codes,

        out_start_offset = start_offset if remove_prompts else 0
        out_codes = out_codes[..., out_start_offset:max_gen_len]

        # ensure the returned codes are all valid
        # assert (out_codes >= 0).all() and (out_codes <= self.card).all()
        return out_codes

from dataclasses import dataclass, field
from itertools import chain
import logging
import math
import re
import typing as tp
import torch
import torch.nn.functional as F
from audiocraft.streaming import StreamingModule
from audiocraft.transformer import StreamingTransformer, create_norm_fn
from dataclasses import dataclass
from functools import partial
from torch import nn
from audiocraft.utils import utils
from audiocraft.activations import get_activation_fn


+# ============================================== From LM.py


logger = logging.getLogger(__name__)
TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
ConditionTensors = tp.Dict[str, ConditionType]
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]

        two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
        **kwargs: Additional parameters for the transformer encoder.
    """
+    def __init__(self,
+                 pattern_provider,
+                 condition_provider,
+                 fuser,
+                 n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
                 hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
                 emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
                 weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
    def num_codebooks(self) -> int:
        return self.n_q

+    def forward(self,
+                sequence,
+                conditions,
+                condition_tensors=None,
+                stage = -1):
        B, K, S = sequence.shape
        assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
            condition_tensors = self.condition_provider(tokenized)
        else:
            assert not conditions, "Shouldn't pass both conditions and condition_tensors."
+
+        input_, cross_attention_input = self.fuser(input_, condition_tensors)  # DEFINE conditioners.py

        out = self.transformer(input_, cross_attention_src=cross_attention_input,
                               src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None))

        return logits  # [B, K, S, card]

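To make the shape bookkeeping in forward easier to follow (sum of K codebook embeddings in, K parallel heads of size card out), here is a small self-contained toy that mirrors only the tensor flow; the backbone is a stand-in, not the streaming transformer used above, and all sizes are assumptions.

import torch
from torch import nn

class ToyCodebookLM(nn.Module):
    def __init__(self, n_q: int = 4, card: int = 2048, dim: int = 128):
        super().__init__()
        # one embedding table per codebook (card + 1 leaves room for a special token)
        self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
        self.backbone = nn.Linear(dim, dim)   # stand-in for the transformer
        self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])

    def forward(self, sequence: torch.Tensor) -> torch.Tensor:
        B, K, S = sequence.shape
        # the K codebook embeddings are summed into a single [B, S, dim] stream
        x = sum(self.emb[k](sequence[:, k]) for k in range(K))
        x = self.backbone(x)
        # one linear head per codebook, stacked back into [B, K, S, card]
        return torch.stack([self.linears[k](x) for k in range(K)], dim=1)

logits = ToyCodebookLM()(torch.randint(0, 2048, (2, 4, 10)))
print(logits.shape)   # torch.Size([2, 4, 10, 2048])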

    def _sample_next_token(self,
                           sequence,

        return next_token

+    # GENERATE class revert_codebook_patterns()
    @torch.no_grad()
    def generate(self,
+                 prompt = None,
+                 conditions = [],
+                 num_samples = None,
                 max_gen_len: int = 256,
                 use_sampling: bool = True,
                 temp: float = 1.0,

                 check: bool = False,
                 callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                 **kwargs) -> torch.Tensor:
+        """Default generation takes random token of top_250 logits

        Args:
+
        Returns:
+            torch.Tensor: tokens
        """
        assert not self.training, "generation shouldn't be used in training mode."
        first_param = next(iter(self.parameters()))
        # the padding structure is exactly the same between train and test.
        # With a batch size of 1, this can be slower though.
        cfg_conditions: CFGConditions
+        # two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+
+        null_conditions = conditions
+        conditions = conditions + null_conditions
+        tokenized = self.condition_provider.tokenize(conditions)
+        cfg_conditions = self.condition_provider(tokenized)
+

        if prompt is None:
            assert num_samples > 0

        gen_codes[..., :start_offset] = prompt
        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+        gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+
        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+        # print('\n=', start_offset_sequence, '\n=')  # 1
        assert start_offset_sequence is not None

        with self.streaming():
            unconditional_state = self.get_streaming_state()
            prev_offset = 0
            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
+
+            # --
+            # print(mask.shape, mask.sum(), 'MSK LM')
+            # torch.Size([4, 39]) tensor(140, device='cuda:0') MSK LM ? Fully 1 normal no special token
+            # --
+
+
            for offset in range(start_offset_sequence, gen_sequence_len):
                # get current sequence (note that the streaming API is providing the caching over previous offsets)
+
                curr_sequence = gen_sequence[..., prev_offset:offset]
                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
                if check:
                    callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
            unconditional_state.clear()

+        out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

        out_start_offset = start_offset if remove_prompts else 0
        out_codes = out_codes[..., out_start_offset:max_gen_len]

        # ensure the returned codes are all valid
+
        # assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+
        return out_codes
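build_pattern_sequence / revert_pattern_sequence above implement the delay interleaving, which is the only pattern kept by this commit (builders.py now imports DelayedPatternProvider alone). Below is a pure-Python toy of the general idea, shifting codebook k by k steps and padding with a special token; it is a conceptual illustration, not the library's exact layout.

n_q, T, SPECIAL = 3, 5, -1
codes = [[f"c{k}t{t}" for t in range(T)] for k in range(n_q)]   # [K][T] toy codes

# build: codebook k is delayed by k steps, total length T + n_q - 1
S = T + n_q - 1
seq = [[codes[k][s - k] if 0 <= s - k < T else SPECIAL for s in range(S)]
       for k in range(n_q)]

# revert: undo the per-codebook shift to recover the original alignment
recovered = [[seq[k][t + k] for t in range(T)] for k in range(n_q)]
assert recovered == codes
for row in seq:
    print(row)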
audiocraft/loaders.py
CHANGED
@@ -101,7 +101,8 @@ def _delete_param(cfg: DictConfig, full_name: str):
    OmegaConf.set_struct(cfg, True)


-def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
    pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)

    OmegaConf.set_struct(cfg, True)


+def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
+                  cache_dir: tp.Optional[str] = None):
    pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
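With the second line of the signature restored, load_lm_model again accepts an optional cache directory. A hedged usage sketch; the checkpoint id is only an example, borrowed from demo.py below.

from audiocraft.loaders import load_lm_model

# reads the checkpoint and rebuilds the LM from its stored config
lm = load_lm_model('facebook/audiogen-medium', device='cpu', cache_dir=None)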
demo.py
CHANGED
@@ -1,15 +1,14 @@
from audiocraft.audiogen import AudioGen #, audio_write
-
-import numpy as np

print('\n\n\n\n___________________')

-txt = '

sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
-sound_generator.set_generation_params(duration=

x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
x /= np.abs(x).max() + 1e-7

-audiofile.write('

from audiocraft.audiogen import AudioGen #, audio_write
+

print('\n\n\n\n___________________')

+txt = 'austrian music'

sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
+sound_generator.set_generation_params(duration=4.7)  # why is generating so long at 14 seconds

x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
x /= np.abs(x).max() + 1e-7

+audiofile.write('del_seane.wav', x, 16000)
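Note that the new demo.py still calls np.abs and audiofile.write even though no corresponding imports are visible in the hunk above; presumably they live on unchanged lines. For completeness, a self-contained version of the script would look roughly like this; the two imports are assumptions, everything else is taken from the diff.

import audiofile     # assumed import, used by audiofile.write below
import numpy as np   # assumed import, used for normalization

from audiocraft.audiogen import AudioGen

txt = 'austrian music'

sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
sound_generator.set_generation_params(duration=4.7)

x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
x /= np.abs(x).max() + 1e-7   # peak-normalize to avoid clipping

audiofile.write('del_seane.wav', x, 16000)   # 16 kHz mono output, path as in the diff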