Dionyssos committed
Commit e366cd5 · 1 Parent(s): 0a8807e

clean unused functions

audiocraft/builders.py CHANGED
@@ -10,22 +10,13 @@ from the Hydra config.
 """
 
 import typing as tp
-
-import audiocraft
 import omegaconf
 import torch
 
 from .encodec import CompressionModel, EncodecModel
 from .lm import LMModel
 from .seanet import SEANetDecoder
-from .codebooks_patterns import (
-    CodebooksPatternProvider,
-    DelayedPatternProvider,
-    MusicLMPattern,
-    ParallelPatternProvider,
-    UnrolledPatternProvider,
-    CoarseFirstPattern,
-)
+from .codebooks_patterns import DelayedPatternProvider
 from .conditioners import (
     BaseConditioner,
     ConditionFuser,
@@ -159,45 +150,18 @@ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
     return fuser
 
 
-def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
-    """Instantiate a codebooks pattern provider object."""
+def get_codebooks_pattern_provider(n_q, cfg):
     pattern_providers = {
-        'parallel': ParallelPatternProvider,
-        'delay': DelayedPatternProvider,
-        'unroll': UnrolledPatternProvider,
-        'coarse_first': CoarseFirstPattern,
-        'musiclm': MusicLMPattern,
+        'delay': DelayedPatternProvider,  # THIS
     }
     name = cfg.modeling
     kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
+
     klass = pattern_providers[name]
     return klass(n_q, **kwargs)
 
 
-def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
-    """Instantiate a debug compression model to be used for unit tests."""
-    assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
-    model_ratios = {
-        16000: [10, 8, 8],   # 25 Hz at 16kHz
-        32000: [10, 8, 16]   # 25 Hz at 32kHz
-    }
-    ratios: tp.List[int] = model_ratios[sample_rate]
-    frame_rate = 25
-    seanet_kwargs: dict = {
-        'n_filters': 4,
-        'n_residual_layers': 1,
-        'dimension': 32,
-        'ratios': ratios,
-    }
-    encoder = SEANetEncoder(**seanet_kwargs)
-    decoder = SEANetDecoder(**seanet_kwargs)
-    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
-    init_x = torch.randn(8, 32, 128)
-    quantizer(init_x, 1)  # initialize kmeans etc.
-    compression_model = EncodecModel(
-        encoder, decoder, quantizer,
-        frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
-    return compression_model.eval()
+
 
 
 def get_diffusion_model(cfg: omegaconf.DictConfig):
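
Editor's note: a minimal usage sketch of the trimmed factory above, not part of the commit. After this change only the 'delay' pattern can be resolved. The audiocraft.builders import path and the config values below are assumptions for illustration.

import omegaconf
from audiocraft.builders import get_codebooks_pattern_provider  # assumed module path

cfg = omegaconf.OmegaConf.create({
    'modeling': 'delay',
    'delay': {'delays': [0, 1, 2, 3]},      # hypothetical per-codebook delays
})
provider = get_codebooks_pattern_provider(n_q=4, cfg=cfg)   # -> DelayedPatternProvider
pattern = provider.get_pattern(8)                           # interleaved layout over 8 timesteps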
audiocraft/codebooks_patterns.py CHANGED
@@ -52,7 +52,7 @@ class Pattern:
         self._validate_layout()
         self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
         self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
-        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
+        print("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
 
     def _validate_layout(self):
         """Runs checks on the layout to ensure a valid pattern is defined.
@@ -356,193 +356,4 @@ class DelayedPatternProvider(CodebooksPatternProvider):
356
  return Pattern(out, n_q=self.n_q, timesteps=timesteps)
357
 
358
 
359
- class ParallelPatternProvider(DelayedPatternProvider):
360
- """Provider for parallel pattern across codebooks.
361
- This pattern provider is a special case of the delayed pattern with actually no delay,
362
- hence delays=repeat(0, n_q).
363
 
364
- Args:
365
- n_q (int): Number of codebooks.
366
- empty_initial (int): Prepend with N empty list of coordinates.
367
- """
368
- def __init__(self, n_q: int, empty_initial: int = 0):
369
- super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
370
-
371
-
372
- class UnrolledPatternProvider(CodebooksPatternProvider):
373
- """Provider for unrolling codebooks pattern.
374
- This pattern provider enables to represent the codebook flattened completely or only to some extend
375
- while also specifying a given delay between the flattened codebooks representation, allowing to
376
- unroll the codebooks in the sequence.
377
-
378
- Example:
379
- 1. Flattening of the codebooks.
380
- By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
381
- taking n_q = 3 and timesteps = 4:
382
- [[1, 2, 3, 4],
383
- [1, 2, 3, 4],
384
- [1, 2, 3, 4]]
385
- will result into:
386
- [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
387
- [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
388
- [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
389
- 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
390
- for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
391
- taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
392
- [[1, 2, 3, 4],
393
- [1, 2, 3, 4],
394
- [1, 2, 3, 4]]
395
- will result into:
396
- [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
397
- [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
398
- [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
399
- 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
400
- allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
401
- same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
402
- and delays = [0, 3, 3]:
403
- [[1, 2, 3, 4],
404
- [1, 2, 3, 4],
405
- [1, 2, 3, 4]]
406
- will result into:
407
- [[S, S, S, 1, S, 2, S, 3, S, 4],
408
- [S, S, S, 1, S, 2, S, 3, S, 4],
409
- [1, 2, 3, S, 4, S, 5, S, 6, S]]
410
-
411
- Args:
412
- n_q (int): Number of codebooks.
413
- flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
414
- the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
415
- have n_q extra steps for each timestep.
416
- delays (list of int, optional): Delay for each of the codebooks. If not defined,
417
- no delay is added and therefore will default to [0] * ``n_q``.
418
- Note that two codebooks that will be flattened to the same inner step
419
- should have the same delay, otherwise the pattern is considered as invalid.
420
- """
421
- FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
422
-
423
- def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
424
- delays: tp.Optional[tp.List[int]] = None):
425
- super().__init__(n_q)
426
- if flattening is None:
427
- flattening = list(range(n_q))
428
- if delays is None:
429
- delays = [0] * n_q
430
- assert len(flattening) == n_q
431
- assert len(delays) == n_q
432
- assert sorted(flattening) == flattening
433
- assert sorted(delays) == delays
434
- self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
435
- self.max_delay = max(delays)
436
-
437
- def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
438
- """Build a flattened codebooks representation as a dictionary of inner step
439
- and the actual codebook indices corresponding to the flattened codebook. For convenience, we
440
- also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
441
- """
442
- flattened_codebooks: dict = {}
443
- for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
444
- if inner_step not in flattened_codebooks:
445
- flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
446
- else:
447
- flat_codebook = flattened_codebooks[inner_step]
448
- assert flat_codebook.delay == delay, (
449
- "Delay and flattening between codebooks is inconsistent: ",
450
- "two codebooks flattened to the same position should have the same delay."
451
- )
452
- flat_codebook.codebooks.append(q)
453
- flattened_codebooks[inner_step] = flat_codebook
454
- return flattened_codebooks
455
-
456
- @property
457
- def _num_inner_steps(self):
458
- """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
459
- """
460
- return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
461
-
462
- def num_virtual_steps(self, timesteps: int) -> int:
463
- return timesteps * self._num_inner_steps + 1
464
-
465
- def get_pattern(self, timesteps: int) -> Pattern:
466
- """Builds pattern for delay across codebooks.
467
-
468
- Args:
469
- timesteps (int): Total number of timesteps.
470
- """
471
- # the PatternLayout is built as a tuple of sequence position and list of coordinates
472
- # so that it can be reordered properly given the required delay between codebooks of given timesteps
473
- indexed_out: list = [(-1, [])]
474
- max_timesteps = timesteps + self.max_delay
475
- for t in range(max_timesteps):
476
- # for each timestep, we unroll the flattened codebooks,
477
- # emitting the sequence step with the corresponding delay
478
- for step in range(self._num_inner_steps):
479
- if step in self._flattened_codebooks:
480
- # we have codebooks at this virtual step to emit
481
- step_codebooks = self._flattened_codebooks[step]
482
- t_for_q = t + step_codebooks.delay
483
- coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
484
- if t_for_q < max_timesteps and t < max_timesteps:
485
- indexed_out.append((t_for_q, coords))
486
- else:
487
- # there is no codebook in this virtual step so we emit an empty list
488
- indexed_out.append((t, []))
489
- out = [coords for _, coords in sorted(indexed_out)]
490
- return Pattern(out, n_q=self.n_q, timesteps=timesteps)
491
-
492
-
493
- class CoarseFirstPattern(CodebooksPatternProvider):
494
- """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
495
- potentially with delays.
496
-
497
- ..Warning:: You must always generate the full training duration at test time, for instance,
498
- 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
499
- location. This is due to the non causality of the remaining codebooks with respect to
500
- the first ones.
501
-
502
- Args:
503
- n_q (int): Number of codebooks.
504
- delays (list of int, optional): Delay for each of the codebooks.
505
- If delays not defined, each codebook is delayed by 1 compared to the previous one.
506
- """
507
- def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
508
- super().__init__(n_q)
509
- if delays is None:
510
- delays = [0] * (n_q - 1)
511
- self.delays = delays
512
- assert len(self.delays) == self.n_q - 1
513
- assert sorted(self.delays) == self.delays
514
-
515
- def get_pattern(self, timesteps: int) -> Pattern:
516
- out: PatternLayout = [[]]
517
- for t in range(timesteps):
518
- out.append([LayoutCoord(t, 0)])
519
- max_delay = max(self.delays)
520
- for t in range(timesteps + max_delay):
521
- v = []
522
- for q, delay in enumerate(self.delays):
523
- t_for_q = t - delay
524
- if t_for_q >= 0:
525
- v.append(LayoutCoord(t_for_q, q + 1))
526
- out.append(v)
527
- return Pattern(out, n_q=self.n_q, timesteps=timesteps)
528
-
529
-
530
- class MusicLMPattern(CodebooksPatternProvider):
531
- """Almost MusicLM style pattern. This is equivalent to full flattening
532
- but in a different order.
533
-
534
- Args:
535
- n_q (int): Number of codebooks.
536
- group_by (int): Number of codebooks to group together.
537
- """
538
- def __init__(self, n_q: int, group_by: int = 2):
539
- super().__init__(n_q)
540
- self.group_by = group_by
541
-
542
- def get_pattern(self, timesteps: int) -> Pattern:
543
- out: PatternLayout = [[]]
544
- for offset in range(0, self.n_q, self.group_by):
545
- for t in range(timesteps):
546
- for q in range(offset, offset + self.group_by):
547
- out.append([LayoutCoord(t, q)])
548
- return Pattern(out, n_q=self.n_q, timesteps=timesteps)
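
Editor's note: only DelayedPatternProvider survives this cleanup. Below is a hedged sketch of what it produces, not part of the commit; the default per-codebook delays and the 2048 special-token id are assumptions for illustration, while the build/revert calls mirror their use in lm.py further down.

import torch
from audiocraft.codebooks_patterns import DelayedPatternProvider

n_q, T = 4, 8
provider = DelayedPatternProvider(n_q=n_q)             # default delays shift codebook k by k steps
pattern = provider.get_pattern(T)
print("sequence steps:", len(pattern.layout))          # the value reported by the print() added above

codes = torch.randint(0, 2048, (2, n_q, T))            # [B, K, T] codec tokens (hypothetical values)
special = 2048                                         # hypothetical id used for padded slots
seq, idx, mask = pattern.build_pattern_sequence(codes, special)                       # -> [B, K, S] interleaved
out, out_idx, out_mask = pattern.revert_pattern_sequence(seq, special_token=special)  # back to [B, K, T]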
 
audiocraft/conditioners.py CHANGED
@@ -410,7 +410,10 @@ class ConditionFuser(StreamingModule):
         # print(f'{self.cond2fuse=}') - self.cond2fuse={'description': 'cross'}
 
         cross_attention_output = cond
-
+        # print(f'{cross_attention_output.shape=} for {input.sum()=}')
+        # cross_attention_output.shape=torch.Size([2, 5, 1536]) for input.sum()=tensor(-0.0650, device='cuda:0')
+        # cross_attention_output.shape=torch.Size([2, 5, 1536]) for input.sum()=tensor(3.7672, device='cuda:0')
+
 
         if self._is_streaming:
             self._streaming_state['offsets'] = offsets + T
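
Editor's note: a hedged sketch of the 'cross' fuse path these debug prints trace, not part of the commit. It assumes ConditionFuser keeps the constructor and forward interface of the copy removed from lm.py below; the shapes match the logged comments.

import torch
from audiocraft.conditioners import ConditionFuser

fuser = ConditionFuser(fuse2cond={'cross': ['description']})
x = torch.randn(2, 10, 1536)                     # transformer input [B, T, D]
desc = torch.randn(2, 5, 1536)                   # text conditioning [B, 5, D]
desc_mask = torch.ones(2, 5, dtype=torch.int)
x_out, cross = fuser(x, {'description': (desc, desc_mask)})
# cross is the tensor whose shape is printed above: torch.Size([2, 5, 1536])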
audiocraft/encodec.py CHANGED
@@ -77,42 +77,7 @@ class CompressionModel(ABC, nn.Module):
         """Set the active number of codebooks used by the quantizer."""
         ...
 
-    @staticmethod
-    def get_pretrained(
-            name: str, device: tp.Union[torch.device, str] = 'cpu'
-            ) -> 'CompressionModel':
-        """Instantiate a CompressionModel from a given pretrained model.
-
-        Args:
-            name (Path or str): name of the pretrained model. See after.
-            device (torch.device or str): Device on which the model is loaded.
-
-        Pretrained models:
-            - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
-            - dac_24khz (same)
-            - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
-            - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
-            - your own model on Hugging Face. Export instructions to come...
-        """
-
-        from . import builders, loaders
-        model: CompressionModel
-        if name in ['dac_44khz', 'dac_24khz']:
-            model_type = name.split('_')[1]
-            logger.info("Getting pretrained compression model from DAC %s", model_type)
-            model = DAC(model_type)
-        elif name in ['debug_compression_model']:
-            logger.info("Getting pretrained compression model for debug")
-            model = builders.get_debug_compression_model()
-        elif Path(name).exists():
-            # We assume here if the path exists that it is in fact an AC checkpoint
-            # that was exported using `audiocraft.utils.export` functions.
-            model = loaders.load_compression_model(name, device=device)
-        else:
-            logger.info("Getting pretrained compression model from HF %s", name)
-            hf_model = HFEncodecModel.from_pretrained(name)
-            model = HFEncodecCompressionModel(hf_model).to(device)
-        return model.to(device).eval()
+
 
 
 class EncodecModel(CompressionModel):
@@ -196,20 +161,13 @@ class EncodecModel(CompressionModel):
         return x
 
     def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
-        """Decode the given codes to a reconstructed representation, using the scale to perform
-        audio denormalization if needed.
-
-        Args:
-            codes (torch.Tensor): Int tensor of shape [B, K, T]
-            scale (torch.Tensor, optional): Float tensor containing the scale value.
-
-        Returns:
-            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
-        """
+        # B,K,T -> B,C,T
         emb = self.decode_latent(codes)
+
         out = self.decoder(emb)
+
         out = self.postprocess(out, scale)
-        # out contains extra padding added by the encoder and decoder
+
         return out
 
     def decode_latent(self, codes: torch.Tensor):
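
Editor's note: a short, hedged usage sketch of the decode contract noted in the new comment; nothing here is added by the commit, and `model` is assumed to be an already-constructed EncodecModel from this repo.

import typing as tp
import torch

def codes_to_waveform(model, codes: torch.Tensor,
                      scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
    # codes: int tensor [B, K, T] from the quantizer; returns float audio [B, C, T'].
    with torch.no_grad():
        return model.decode(codes, scale)   # decode_latent -> decoder -> postprocess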
audiocraft/lm.py CHANGED
@@ -1,769 +1,27 @@
1
- # ========================= From conditioners.py
2
- import soundfile
3
- from collections import defaultdict
4
- from copy import deepcopy
5
  from dataclasses import dataclass, field
6
  from itertools import chain
7
  import logging
8
  import math
9
- from pathlib import Path
10
- import random
11
  import re
12
  import typing as tp
13
- import warnings
14
- import einops
15
- from num2words import num2words
16
- import spacy
17
- from transformers import T5EncoderModel, T5Tokenizer # type: ignore
18
  import torch
19
  import torch.nn.functional as F
20
- from torch.nn.utils.rnn import pad_sequence
21
  from audiocraft.streaming import StreamingModule
22
- from audiocraft.transformer import create_sin_embedding
23
- from audiocraft.utils.autocast import TorchAutocast
24
- from audiocraft.utils.utils import collate, length_to_mask
25
  from audiocraft.transformer import StreamingTransformer, create_norm_fn
26
  from dataclasses import dataclass
27
  from functools import partial
28
- import logging
29
- import math
30
- import typing as tp
31
-
32
-
33
  from torch import nn
34
-
35
  from audiocraft.utils import utils
36
- from audiocraft.codebooks_patterns import CodebooksPatternProvider
37
  from audiocraft.activations import get_activation_fn
38
 
39
 
40
-
41
 
42
 
43
  logger = logging.getLogger(__name__)
44
  TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
45
  ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
46
 
47
-
48
- class WavCondition(tp.NamedTuple):
49
- wav: torch.Tensor
50
- length: torch.Tensor
51
- sample_rate: tp.List[int]
52
- path: tp.List[tp.Optional[str]] = []
53
- seek_time: tp.List[tp.Optional[float]] = []
54
-
55
-
56
- class JointEmbedCondition(tp.NamedTuple):
57
- wav: torch.Tensor
58
- text: tp.List[tp.Optional[str]]
59
- length: torch.Tensor
60
- sample_rate: tp.List[int]
61
- path: tp.List[tp.Optional[str]] = []
62
- seek_time: tp.List[tp.Optional[float]] = []
63
-
64
-
65
- @dataclass
66
- class ConditioningAttributes:
67
- text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
68
- wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
69
- joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
70
-
71
- def __getitem__(self, item):
72
- return getattr(self, item)
73
-
74
- @property
75
- def text_attributes(self):
76
- return self.text.keys()
77
-
78
- @property
79
- def wav_attributes(self):
80
- return self.wav.keys()
81
-
82
- @property
83
- def joint_embed_attributes(self):
84
- return self.joint_embed.keys()
85
-
86
- @property
87
- def attributes(self):
88
- return {
89
- "text": self.text_attributes,
90
- "wav": self.wav_attributes,
91
- "joint_embed": self.joint_embed_attributes,
92
- }
93
-
94
- def to_flat_dict(self):
95
- return {
96
- **{f"text.{k}": v for k, v in self.text.items()},
97
- **{f"wav.{k}": v for k, v in self.wav.items()},
98
- **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
99
- }
100
-
101
- @classmethod
102
- def from_flat_dict(cls, x):
103
- out = cls()
104
- for k, v in x.items():
105
- kind, att = k.split(".")
106
- out[kind][att] = v
107
- return out
108
-
109
-
110
-
111
-
112
-
113
- def nullify_condition(condition: ConditionType, dim: int = 1):
114
- """Transform an input condition to a null condition.
115
- The way it is done by converting it to a single zero vector similarly
116
- to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
117
-
118
- Args:
119
- condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
120
- dim (int): The dimension that will be truncated (should be the time dimension)
121
- WARNING!: dim should not be the batch dimension!
122
- Returns:
123
- ConditionType: A tuple of null condition and mask
124
- """
125
- assert dim != 0, "dim cannot be the batch dimension!"
126
- assert isinstance(condition, tuple) and \
127
- isinstance(condition[0], torch.Tensor) and \
128
- isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
129
- cond, mask = condition
130
- B = cond.shape[0]
131
- last_dim = cond.dim() - 1
132
- out = cond.transpose(dim, last_dim)
133
- out = 0. * out[..., :1]
134
- out = out.transpose(dim, last_dim)
135
- mask = torch.zeros((B, 1), device=out.device).int()
136
- assert cond.dim() == out.dim()
137
- return out, mask
138
-
139
-
140
- def nullify_wav(cond: WavCondition) -> WavCondition:
141
- """Transform a WavCondition to a nullified WavCondition.
142
- It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
143
-
144
- Args:
145
- cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
146
- Returns:
147
- WavCondition: Nullified wav condition.
148
- """
149
- null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
150
- return WavCondition(
151
- wav=null_wav,
152
- length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
153
- sample_rate=cond.sample_rate,
154
- path=[None] * cond.wav.shape[0],
155
- seek_time=[None] * cond.wav.shape[0],
156
- )
157
-
158
-
159
- def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
160
- """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
161
- and replacing metadata by dummy attributes.
162
-
163
- Args:
164
- cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
165
- """
166
- null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
167
- return JointEmbedCondition(
168
- wav=null_wav, text=[None] * len(embed.text),
169
- length=torch.LongTensor([0]).to(embed.wav.device),
170
- sample_rate=embed.sample_rate,
171
- path=[None] * embed.wav.shape[0],
172
- seek_time=[0] * embed.wav.shape[0],
173
- )
174
-
175
-
176
- class Tokenizer:
177
- """Base tokenizer implementation
178
- (in case we want to introduce more advances tokenizers in the future).
179
- """
180
- def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
181
- raise NotImplementedError()
182
-
183
-
184
- class WhiteSpaceTokenizer(Tokenizer):
185
- """This tokenizer should be used for natural language descriptions.
186
- For example:
187
- ["he didn't, know he's going home.", 'shorter sentence'] =>
188
- [[78, 62, 31, 4, 78, 25, 19, 34],
189
- [59, 77, 0, 0, 0, 0, 0, 0]]
190
- """
191
- PUNCTUATION = "?:!.,;"
192
-
193
- def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
194
- lemma: bool = True, stopwords: bool = True) -> None:
195
- self.n_bins = n_bins
196
- self.pad_idx = pad_idx
197
- self.lemma = lemma
198
- self.stopwords = stopwords
199
- try:
200
- self.nlp = spacy.load(language)
201
- except IOError:
202
- spacy.cli.download(language) # type: ignore
203
- self.nlp = spacy.load(language)
204
-
205
- @tp.no_type_check
206
- def __call__(self, texts: tp.List[tp.Optional[str]],
207
- return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
208
- """Take a list of strings and convert them to a tensor of indices.
209
-
210
- Args:
211
- texts (list[str]): List of strings.
212
- return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
213
- Returns:
214
- tuple[torch.Tensor, torch.Tensor]:
215
- - Indices of words in the LUT.
216
- - And a mask indicating where the padding tokens are
217
- """
218
- output, lengths = [], []
219
- texts = deepcopy(texts)
220
- for i, text in enumerate(texts):
221
- # if current sample doesn't have a certain attribute, replace with pad token
222
- if text is None:
223
- output.append(torch.Tensor([self.pad_idx]))
224
- lengths.append(0)
225
- continue
226
-
227
- # convert numbers to words
228
- text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore
229
- # normalize text
230
- text = self.nlp(text) # type: ignore
231
- # remove stopwords
232
- if self.stopwords:
233
- text = [w for w in text if not w.is_stop] # type: ignore
234
- # remove punctuation
235
- text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore
236
- # lemmatize if needed
237
- text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
238
-
239
- texts[i] = " ".join(text)
240
- lengths.append(len(text))
241
- # convert to tensor
242
- tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
243
- output.append(tokens)
244
-
245
- mask = length_to_mask(torch.IntTensor(lengths)).int()
246
- padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
247
- if return_text:
248
- return padded_output, mask, texts # type: ignore
249
- return padded_output, mask
250
-
251
-
252
- class NoopTokenizer(Tokenizer):
253
- """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
254
- The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
255
- strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
256
- split it to ["Jeff", "Buckley"] and return an index per word.
257
-
258
- For example:
259
- ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
260
- ["Metal", "Rock", "Classical"] => [0, 223, 51]
261
- """
262
- def __init__(self, n_bins: int, pad_idx: int = 0):
263
- self.n_bins = n_bins
264
- self.pad_idx = pad_idx
265
-
266
- def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
267
- output, lengths = [], []
268
- for text in texts:
269
- # if current sample doesn't have a certain attribute, replace with pad token
270
- if text is None:
271
- output.append(self.pad_idx)
272
- lengths.append(0)
273
- else:
274
- output.append(hash_trick(text, self.n_bins))
275
- lengths.append(1)
276
-
277
- tokens = torch.LongTensor(output).unsqueeze(1)
278
- mask = length_to_mask(torch.IntTensor(lengths)).int()
279
- return tokens, mask
280
-
281
-
282
- class BaseConditioner(nn.Module):
283
- """Base model for all conditioner modules.
284
- We allow the output dim to be different than the hidden dim for two reasons:
285
- 1) keep our LUTs small when the vocab is large;
286
- 2) make all condition dims consistent.
287
-
288
- Args:
289
- dim (int): Hidden dim of the model.
290
- output_dim (int): Output dim of the conditioner.
291
- """
292
- def __init__(self, dim: int, output_dim: int):
293
- super().__init__()
294
- self.dim = dim
295
- self.output_dim = output_dim
296
- self.output_proj = nn.Linear(dim, output_dim)
297
-
298
-
299
-
300
- def forward(self, inputs: tp.Any) -> ConditionType:
301
- """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
302
- Outputs a ConditionType, after the input data was embedded as a dense vector.
303
-
304
- Returns:
305
- ConditionType:
306
- - A tensor of size [B, T, D] where B is the batch size, T is the length of the
307
- output embedding and D is the dimension of the embedding.
308
- - And a mask indicating where the padding tokens.
309
- """
310
- raise NotImplementedError()
311
-
312
-
313
- class TextConditioner(BaseConditioner):
314
- ...
315
-
316
-
317
-
318
-
319
-
320
- class T5Conditioner(TextConditioner):
321
- """T5-based TextConditioner.
322
-
323
- Args:
324
- name (str): Name of the T5 model.
325
- output_dim (int): Output dim of the conditioner.
326
- finetune (bool): Whether to fine-tune T5 at train time.
327
- device (str): Device for T5 Conditioner.
328
- autocast_dtype (tp.Optional[str], optional): Autocast dtype.
329
- word_dropout (float, optional): Word dropout probability.
330
- normalize_text (bool, optional): Whether to apply text normalization.
331
- """
332
- MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
333
- "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
334
- "google/flan-t5-xl", "google/flan-t5-xxl"]
335
- MODELS_DIMS = {
336
- "t5-small": 512,
337
- "t5-base": 768,
338
- "t5-large": 1024,
339
- "t5-3b": 1024,
340
- "t5-11b": 1024,
341
- "google/flan-t5-small": 512,
342
- "google/flan-t5-base": 768,
343
- "google/flan-t5-large": 1024,
344
- "google/flan-t5-3b": 1024,
345
- "google/flan-t5-11b": 1024,
346
- }
347
-
348
- def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
349
- autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
350
- normalize_text: bool = False):
351
- assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
352
- super().__init__(self.MODELS_DIMS[name], output_dim)
353
- self.device = device
354
- self.name = name
355
- self.finetune = finetune
356
- self.word_dropout = word_dropout
357
- if autocast_dtype is None or self.device == 'cpu':
358
- self.autocast = TorchAutocast(enabled=False)
359
- if self.device != 'cpu':
360
- logger.warning("T5 has no autocast, this might lead to NaN")
361
- else:
362
- dtype = getattr(torch, autocast_dtype)
363
- assert isinstance(dtype, torch.dtype)
364
- logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
365
- self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
366
- # Let's disable logging temporarily because T5 will vomit some errors otherwise.
367
- # thanks https://gist.github.com/simon-weber/7853144
368
- previous_level = logging.root.manager.disable
369
- logging.disable(logging.ERROR)
370
- with warnings.catch_warnings():
371
- warnings.simplefilter("ignore")
372
- try:
373
- self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
374
- t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
375
- finally:
376
- logging.disable(previous_level)
377
- if finetune:
378
- self.t5 = t5
379
- else:
380
- # this makes sure that the t5 models is not part
381
- # of the saved checkpoint
382
- self.__dict__['t5'] = t5.to(device)
383
-
384
- self.normalize_text = normalize_text
385
- if normalize_text:
386
- self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
387
-
388
- def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
389
- # if current sample doesn't have a certain attribute, replace with empty string
390
- entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
391
- if self.normalize_text:
392
- _, _, entries = self.text_normalizer(entries, return_text=True)
393
- if self.word_dropout > 0. and self.training:
394
- new_entries = []
395
- for entry in entries:
396
- words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
397
- new_entries.append(" ".join(words))
398
- entries = new_entries
399
-
400
- empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
401
-
402
- inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
403
- mask = inputs['attention_mask']
404
- mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
405
- return inputs
406
-
407
- def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
408
- mask = inputs['attention_mask']
409
- with torch.set_grad_enabled(self.finetune), self.autocast:
410
- embeds = self.t5(**inputs).last_hidden_state
411
- embeds = self.output_proj(embeds.to(self.output_proj.weight))
412
- embeds = (embeds * mask.unsqueeze(-1))
413
- return embeds, mask
414
-
415
-
416
-
417
-
418
-
419
-
420
-
421
-
422
- class JointEmbeddingConditioner(BaseConditioner):
423
- """Joint embedding conditioning supporting both audio or text conditioning.
424
-
425
- Args:
426
- dim (int): Dimension.
427
- output_dim (int): Output dimension.
428
- device (str): Device.
429
- attribute (str): Attribute used by the conditioner.
430
- autocast_dtype (str): Autocast for the conditioner.
431
- quantize (bool): Whether to quantize the CLAP embedding.
432
- n_q (int): Number of residual quantizers (used if quantize is true).
433
- bins (int): Quantizers' codebooks size (used if quantize is true).
434
- kwargs: Additional parameters for residual vector quantizer.
435
- """
436
- def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
437
- autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
438
- n_q: int = 12, bins: int = 1024, **kwargs):
439
- super().__init__(dim=dim, output_dim=output_dim)
440
- self.device = device
441
- self.attribute = attribute
442
- if autocast_dtype is None or device == 'cpu':
443
- self.autocast = TorchAutocast(enabled=False)
444
- logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
445
- else:
446
- dtype = getattr(torch, autocast_dtype)
447
- assert isinstance(dtype, torch.dtype)
448
- logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
449
- self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
450
- # residual vector quantizer to discretize the conditioned embedding
451
- self.quantizer=None
452
- if quantize:
453
- print('\n\n\n\nWANTS TO QUANTIZE on Inference\n\n\n\n')
454
- # self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
455
-
456
- def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
457
- """Get joint embedding in latent space from the inputs.
458
-
459
- Returns:
460
- tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
461
- and corresponding empty indexes.
462
- """
463
- raise NotImplementedError()
464
-
465
- def forward(self, x: JointEmbedCondition) -> ConditionType:
466
- with self.autocast:
467
- embed, empty_idx = self._get_embed(x)
468
- if self.quantizer is not None:
469
- embed = embed.view(-1, self.dim, 1)
470
- q_res = self.quantizer(embed, frame_rate=1)
471
- out_embed = q_res.x.view(-1, self.dim)
472
- else:
473
- out_embed = embed
474
- out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
475
- mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
476
- mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
477
- out_embed = (out_embed * mask.unsqueeze(-1))
478
- return out_embed, mask
479
-
480
- def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
481
- return x
482
-
483
-
484
-
485
-
486
-
487
-
488
-
489
-
490
-
491
-
492
-
493
- class ConditioningProvider(nn.Module):
494
- """Prepare and provide conditions given all the supported conditioners.
495
-
496
- Args:
497
- conditioners (dict): Dictionary of conditioners.
498
- device (torch.device or str, optional): Device for conditioners and output condition types.
499
- """
500
- def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
501
- super().__init__()
502
- self.device = device
503
- self.conditioners = nn.ModuleDict(conditioners)
504
-
505
- @property
506
- def joint_embed_conditions(self):
507
- return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
508
-
509
- @property
510
- def has_joint_embed_conditions(self):
511
- return len(self.joint_embed_conditions) > 0
512
-
513
- @property
514
- def text_conditions(self):
515
- return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
516
-
517
- @property
518
- def wav_conditions(self):
519
- return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
520
-
521
- @property
522
- def has_wav_condition(self):
523
- return len(self.wav_conditions) > 0
524
-
525
- def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
526
- """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
527
- The output is for example:
528
- {
529
- "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
530
- "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
531
- ...
532
- }
533
-
534
- Args:
535
- tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
536
- """
537
- output = {}
538
- for attribute, inputs in tokenized.items():
539
- condition, mask = self.conditioners[attribute](inputs)
540
- output[attribute] = (condition, mask)
541
- return output
542
-
543
- def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
544
- """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
545
- are the attributes and the values are the aggregated input per attribute.
546
- For example:
547
- Input:
548
- [
549
- ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
550
- ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
551
- ]
552
- Output:
553
- {
554
- "genre": ["Rock", "Hip-hop"],
555
- "description": ["A rock song with a guitar solo", "A hip-hop verse"]
556
- }
557
-
558
- Args:
559
- samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
560
- Returns:
561
- dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
562
- """
563
- out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
564
- texts = [x.text for x in samples]
565
- for text in texts:
566
- for condition in self.text_conditions:
567
- out[condition].append(text[condition])
568
- return out
569
-
570
- def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
571
- """Generate a dict where the keys are attributes by which we fetch similar wavs,
572
- and the values are Tensors of wavs according to said attributes.
573
-
574
- *Note*: by the time the samples reach this function, each sample should have some waveform
575
- inside the "wav" attribute. It should be either:
576
- 1. A real waveform
577
- 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
578
- 3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
579
-
580
- Args:
581
- samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
582
- Returns:
583
- dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
584
- """
585
- wavs = defaultdict(list)
586
- lengths = defaultdict(list)
587
- sample_rates = defaultdict(list)
588
- paths = defaultdict(list)
589
- seek_times = defaultdict(list)
590
- out: tp.Dict[str, WavCondition] = {}
591
-
592
- for sample in samples:
593
- for attribute in self.wav_conditions:
594
- wav, length, sample_rate, path, seek_time = sample.wav[attribute]
595
- assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
596
- assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
597
- # mono-channel conditioning
598
- wav = wav.mean(1, keepdim=True) # [1, 1, T]
599
- wavs[attribute].append(wav.flatten()) # [T]
600
- lengths[attribute].append(length)
601
- sample_rates[attribute].extend(sample_rate)
602
- paths[attribute].extend(path)
603
- seek_times[attribute].extend(seek_time)
604
-
605
- # stack all wavs to a single tensor
606
- for attribute in self.wav_conditions:
607
- stacked_wav, _ = collate(wavs[attribute], dim=0)
608
- out[attribute] = WavCondition(
609
- stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
610
- paths[attribute], seek_times[attribute])
611
-
612
- return out
613
-
614
- def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
615
- """Generate a dict where the keys are attributes by which we compute joint embeddings,
616
- and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
617
-
618
- Args:
619
- samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
620
- Returns:
621
- A dictionary mapping an attribute name to joint embeddings.
622
- """
623
- texts = defaultdict(list)
624
- wavs = defaultdict(list)
625
- lengths = defaultdict(list)
626
- sample_rates = defaultdict(list)
627
- paths = defaultdict(list)
628
- seek_times = defaultdict(list)
629
- channels: int = 0
630
-
631
- out = {}
632
- for sample in samples:
633
- for attribute in self.joint_embed_conditions:
634
- wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
635
- assert wav.dim() == 3
636
- if channels == 0:
637
- channels = wav.size(1)
638
- else:
639
- assert channels == wav.size(1), "not all audio has same number of channels in batch"
640
- assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
641
- wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T]
642
- wavs[attribute].append(wav)
643
- texts[attribute].extend(text)
644
- lengths[attribute].append(length)
645
- sample_rates[attribute].extend(sample_rate)
646
- paths[attribute].extend(path)
647
- seek_times[attribute].extend(seek_time)
648
-
649
- for attribute in self.joint_embed_conditions:
650
- stacked_texts = texts[attribute]
651
- stacked_paths = paths[attribute]
652
- stacked_seek_times = seek_times[attribute]
653
- stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
654
- stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
655
- stacked_sample_rates = sample_rates[attribute]
656
- stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
657
- assert stacked_lengths.size(0) == stacked_wavs.size(0)
658
- assert len(stacked_sample_rates) == stacked_wavs.size(0)
659
- assert len(stacked_texts) == stacked_wavs.size(0)
660
- out[attribute] = JointEmbedCondition(
661
- text=stacked_texts, wav=stacked_wavs,
662
- length=stacked_lengths, sample_rate=stacked_sample_rates,
663
- path=stacked_paths, seek_time=stacked_seek_times)
664
-
665
- return out
666
-
667
-
668
- class ConditionFuser(StreamingModule):
669
- """Condition fuser handles the logic to combine the different conditions
670
- to the actual model input.
671
-
672
- Args:
673
- fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
674
- each condition. For example:
675
- {
676
- "prepend": ["description"],
677
- "sum": ["genre", "bpm"],
678
- "cross": ["description"],
679
- }
680
- cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
681
- cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
682
- """
683
- FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
684
-
685
- def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
686
- cross_attention_pos_emb_scale: float = 1.0):
687
- super().__init__()
688
- assert all(
689
- [k in self.FUSING_METHODS for k in fuse2cond.keys()]
690
- ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
691
- self.cross_attention_pos_emb = cross_attention_pos_emb
692
- self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
693
- self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
694
- self.cond2fuse: tp.Dict[str, str] = {}
695
- for fuse_method, conditions in fuse2cond.items():
696
- for condition in conditions:
697
- self.cond2fuse[condition] = fuse_method
698
-
699
- def forward(
700
- self,
701
- input: torch.Tensor,
702
- conditions: tp.Dict[str, ConditionType]
703
- ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
704
- """Fuse the conditions to the provided model input.
705
-
706
- Args:
707
- input (torch.Tensor): Transformer input.
708
- conditions (dict[str, ConditionType]): Dict of conditions.
709
- Returns:
710
- tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
711
- after the conditions have been fused. The second output tensor is the tensor
712
- used for cross-attention or None if no cross attention inputs exist.
713
- """
714
- B, T, _ = input.shape
715
-
716
- if 'offsets' in self._streaming_state:
717
- first_step = False
718
- offsets = self._streaming_state['offsets']
719
- else:
720
- first_step = True
721
- offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
722
-
723
- assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
724
- f"given conditions contain unknown attributes for fuser, " \
725
- f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
726
- cross_attention_output = None
727
- for cond_type, (cond, cond_mask) in conditions.items():
728
- op = self.cond2fuse[cond_type]
729
- if op == 'sum':
730
- input += cond
731
- elif op == 'input_interpolate':
732
- cond = einops.rearrange(cond, "b t d -> b d t")
733
- cond = F.interpolate(cond, size=input.shape[1])
734
- input += einops.rearrange(cond, "b d t -> b t d")
735
- elif op == 'prepend':
736
- if first_step:
737
- input = torch.cat([cond, input], dim=1)
738
- elif op == 'cross':
739
- if cross_attention_output is not None:
740
- cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
741
- else:
742
- cross_attention_output = cond
743
- else:
744
- raise ValueError(f"unknown op ({op})")
745
-
746
- if self.cross_attention_pos_emb and cross_attention_output is not None:
747
- print('SIN EMBED')
748
- positions = torch.arange(
749
- cross_attention_output.shape[1],
750
- device=cross_attention_output.device
751
- ).view(1, -1, 1)
752
- pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
753
- cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
754
-
755
- if self._is_streaming:
756
- self._streaming_state['offsets'] = offsets + T
757
-
758
- return input, cross_attention_output
759
-
760
-
761
-
762
- # ============================================== From LM.py
763
-
764
-
765
-
766
- logger = logging.getLogger(__name__)
767
  ConditionTensors = tp.Dict[str, ConditionType]
768
  CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
769
 
@@ -876,8 +134,11 @@ class LMModel(StreamingModule):
876
  two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
877
  **kwargs: Additional parameters for the transformer encoder.
878
  """
879
- def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
880
- fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
 
 
 
881
  hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
882
  emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
883
  weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
@@ -952,27 +213,11 @@ class LMModel(StreamingModule):
952
  def num_codebooks(self) -> int:
953
  return self.n_q
954
 
955
- def forward(self, sequence: torch.Tensor,
956
- conditions: tp.List[ConditioningAttributes],
957
- condition_tensors: tp.Optional[ConditionTensors] = None,
958
- stage: int = -1) -> torch.Tensor:
959
- """Apply language model on sequence and conditions.
960
- Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
961
- S the sequence steps, return the logits with shape [B, card, K, S].
962
-
963
- Args:
964
- indices (torch.Tensor): Indices of the codes to model.
965
- conditions (list of ConditioningAttributes): Conditions to use when modeling
966
- the given codes. Note that when evaluating multiple time with the same conditioning
967
- you should pre-compute those and pass them as `condition_tensors`.
968
- condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
969
- tensors, see `conditions`.
970
- stage (int): The codebook level that is being predicted. Relevant for MAGNeT
971
- in which prediction is done in a codebook-by-codebook manner.
972
- Takes values in range(n_q), and ignored by default.
973
- Returns:
974
- torch.Tensor: Logits.
975
- """
976
  B, K, S = sequence.shape
977
  assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
978
  input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
@@ -983,8 +228,8 @@ class LMModel(StreamingModule):
983
  condition_tensors = self.condition_provider(tokenized)
984
  else:
985
  assert not conditions, "Shouldn't pass both conditions and condition_tensors."
986
-
987
- input_, cross_attention_input = self.fuser(input_, condition_tensors)
988
 
989
  out = self.transformer(input_, cross_attention_src=cross_attention_input,
990
  src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None))
@@ -999,60 +244,6 @@ class LMModel(StreamingModule):
999
 
1000
  return logits # [B, K, S, card]
1001
 
1002
- def compute_predictions(
1003
- self, codes: torch.Tensor,
1004
- conditions: tp.List[ConditioningAttributes],
1005
- condition_tensors: tp.Optional[ConditionTensors] = None,
1006
- stage: int = -1,
1007
- keep_only_valid_steps: bool = True) -> LMOutput:
1008
- """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
1009
- forward using the specified codes interleaving pattern.
1010
-
1011
- Args:
1012
- codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
1013
- K the number of codebooks and T the number of timesteps.
1014
- conditions (list of ConditioningAttributes): conditionings to use when modeling
1015
- the given codes. Note that when evaluating multiple time with the same conditioning
1016
- you should pre-compute those and pass them as `condition_tensors`.
1017
- condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
1018
- tensors, see `conditions`.
1019
- stage (int): The codebook level that is being predicted. Relevant for MAGNeT
1020
- in which prediction is done in a codebook-by-codebook manner.
1021
- Takes values in range(n_q), and ignored by default.
1022
- keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
1023
- Steps that are beyond valid steps will be replaced by the special_token in that case.
1024
- Returns:
1025
- LMOutput: Language model outputs
1026
- logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
1027
- i.e. the first item corresponds to logits to predict the first code, meaning that
1028
- no additional shifting of codes and logits is required.
1029
- mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
1030
- Given the specified interleaving strategies, parts of the logits and codes should
1031
- not be considered as valid predictions because of invalid context.
1032
- """
1033
- B, K, T = codes.shape
1034
- codes = codes.contiguous()
1035
- # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
1036
- # what is the T is it 2048 ?
1037
- # and then what is pattern -> another function?
1038
- pattern = self.pattern_provider.get_pattern(T)
1039
- sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
1040
- codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps,
1041
- )
1042
-
1043
- # apply model on pattern sequence
1044
- model = self if self._fsdp is None else self._fsdp
1045
- logits = model(sequence_codes, conditions, condition_tensors, stage=stage) # [B, K, S, card]
1046
- # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
1047
- # and provide the corresponding mask over invalid positions of tokens
1048
- logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
1049
- # note: we use nans as special token to make it obvious if we feed unexpected logits
1050
- logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
1051
- logits, float('nan'), keep_only_valid_steps=keep_only_valid_steps
1052
- )
1053
- logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
1054
- logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
1055
- return LMOutput(logits, logits_mask)
1056
 
1057
  def _sample_next_token(self,
1058
  sequence,
@@ -1127,11 +318,12 @@ class LMModel(StreamingModule):
1127
 
1128
  return next_token
1129
 
 
1130
  @torch.no_grad()
1131
  def generate(self,
1132
- prompt: tp.Optional[torch.Tensor] = None,
1133
- conditions: tp.List[ConditioningAttributes] = [],
1134
- num_samples: tp.Optional[int] = None,
1135
  max_gen_len: int = 256,
1136
  use_sampling: bool = True,
1137
  temp: float = 1.0,
@@ -1143,25 +335,12 @@ class LMModel(StreamingModule):
1143
  check: bool = False,
1144
  callback: tp.Optional[tp.Callable[[int, int], None]] = None,
1145
  **kwargs) -> torch.Tensor:
1146
- """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
1147
- be performed in a greedy fashion or using sampling with top K and top P strategies.
1148
 
1149
  Args:
1150
- prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
1151
- conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
1152
- num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
1153
- max_gen_len (int): Maximum generation length.
1154
- use_sampling (bool): Whether to use a sampling strategy or not.
1155
- temp (float): Sampling temperature.
1156
- top_k (int): K for "top-k" sampling.
1157
- top_p (float): P for "top-p" sampling.
1158
- cfg_coeff (float, optional): Classifier-free guidance coefficient.
1159
- two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
1160
- remove_prompts (bool): Whether to remove prompts from generation or not.
1161
- check (bool): Whether to apply further checks on generated sequence.
1162
- callback (Callback, optional): Callback function to report generation progress.
1163
  Returns:
1164
- torch.Tensor: Generated tokens.
1165
  """
1166
  assert not self.training, "generation shouldn't be used in training mode."
1167
  first_param = next(iter(self.parameters()))
@@ -1190,20 +369,13 @@ class LMModel(StreamingModule):
1190
  # the padding structure is exactly the same between train and test.
1191
  # With a batch size of 1, this can be slower though.
1192
  cfg_conditions: CFGConditions
1193
- two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
1194
- if conditions:
1195
- null_conditions = conditions
1196
- if two_step_cfg:
1197
- cfg_conditions = (
1198
- self.condition_provider(self.condition_provider.tokenize(conditions)),
1199
- self.condition_provider(self.condition_provider.tokenize(null_conditions)),
1200
- )
1201
- else:
1202
- conditions = conditions + null_conditions
1203
- tokenized = self.condition_provider.tokenize(conditions)
1204
- cfg_conditions = self.condition_provider(tokenized)
1205
- else:
1206
- cfg_conditions = {}
1207
 
1208
  if prompt is None:
1209
  assert num_samples > 0
@@ -1222,18 +394,26 @@ class LMModel(StreamingModule):
1222
 
1223
  gen_codes[..., :start_offset] = prompt
1224
  # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
1225
- gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
1226
- # retrieve the start_offset in the sequence:
1227
- # it is the first sequence step that contains the `start_offset` timestep
1228
  start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
 
1229
  assert start_offset_sequence is not None
1230
 
1231
  with self.streaming():
1232
  unconditional_state = self.get_streaming_state()
1233
  prev_offset = 0
1234
  gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
1235
  for offset in range(start_offset_sequence, gen_sequence_len):
1236
  # get current sequence (note that the streaming API is providing the caching over previous offsets)
 
1237
  curr_sequence = gen_sequence[..., prev_offset:offset]
1238
  curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
1239
  if check:
@@ -1268,11 +448,13 @@ class LMModel(StreamingModule):
1268
  callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
1269
  unconditional_state.clear()
1270
 
1271
- out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
1272
 
1273
  out_start_offset = start_offset if remove_prompts else 0
1274
  out_codes = out_codes[..., out_start_offset:max_gen_len]
1275
 
1276
  # ensure the returned codes are all valid
 
1277
  # assert (out_codes >= 0).all() and (out_codes <= self.card).all()
 
1278
  return out_codes
 
1
  from dataclasses import dataclass, field
2
  from itertools import chain
3
  import logging
4
  import math
 
 
5
  import re
6
  import typing as tp
7
  import torch
8
  import torch.nn.functional as F
 
9
  from audiocraft.streaming import StreamingModule
10
  from audiocraft.transformer import StreamingTransformer, create_norm_fn
11
  from dataclasses import dataclass
12
  from functools import partial
13
  from torch import nn
 
14
  from audiocraft.utils import utils
 
15
  from audiocraft.activations import get_activation_fn
16
 
17
 
18
+ # ============================================== From LM.py
19
 
20
 
21
  logger = logging.getLogger(__name__)
22
  TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
23
  ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
24
 
 
 
25
  ConditionTensors = tp.Dict[str, ConditionType]
26
  CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
27
 
 
134
  two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
135
  **kwargs: Additional parameters for the transformer encoder.
136
  """
137
+ def __init__(self,
138
+ pattern_provider,
139
+ condition_provider,
140
+ fuser,
141
+ n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
142
  hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
143
  emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
144
  weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
 
213
  def num_codebooks(self) -> int:
214
  return self.n_q
215
 
216
+ def forward(self,
217
+ sequence,
218
+ conditions,
219
+ condition_tensors=None,
220
+ stage = -1):
221
  B, K, S = sequence.shape
222
  assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
223
  input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
 
228
  condition_tensors = self.condition_provider(tokenized)
229
  else:
230
  assert not conditions, "Shouldn't pass both conditions and condition_tensors."
231
+
232
+ input_, cross_attention_input = self.fuser(input_, condition_tensors)  # ConditionFuser (defined in conditioners.py)
233
 
234
  out = self.transformer(input_, cross_attention_src=cross_attention_input,
235
  src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None))
 
244
 
245
  return logits # [B, K, S, card]
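The forward() hunk above sums the K per-codebook embeddings into a single transformer input and reads out one set of logits per codebook. A minimal standalone sketch of that idea follows; the module names, batch size and sequence length here are illustrative assumptions, not the exact audiocraft code.

```python
import torch
import torch.nn as nn

K, card, dim = 4, 1024, 128                      # codebooks, vocabulary size, model width (assumed)
emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(K)])  # +1 slot for the special token
linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(K)])

sequence = torch.randint(0, card, (2, K, 35))    # [B, K, S] interleaved token codes
input_ = sum(emb[k](sequence[:, k]) for k in range(K))                # [B, S, dim]
out = input_                                     # stand-in for self.transformer(...) + output norm
logits = torch.stack([linears[k](out) for k in range(K)], dim=1)      # [B, K, S, card]
print(logits.shape)                              # torch.Size([2, 4, 35, 1024])
```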
246
 
247
 
248
  def _sample_next_token(self,
249
  sequence,
 
318
 
319
  return next_token
320
 
321
+ # generation: interleave codes with the codebook pattern, then revert_pattern_sequence() at the end
322
  @torch.no_grad()
323
  def generate(self,
324
+ prompt = None,
325
+ conditions = [],
326
+ num_samples = None,
327
  max_gen_len: int = 256,
328
  use_sampling: bool = True,
329
  temp: float = 1.0,
 
335
  check: bool = False,
336
  callback: tp.Optional[tp.Callable[[int, int], None]] = None,
337
  **kwargs) -> torch.Tensor:
338
+ """Default generation samples each next token from the top-250 logits.
 
339
 
340
  Args:
341
+
342
  Returns:
343
+ torch.Tensor: Generated tokens.
344
  """
345
  assert not self.training, "generation shouldn't be used in training mode."
346
  first_param = next(iter(self.parameters()))
 
369
  # the padding structure is exactly the same between train and test.
370
  # With a batch size of 1, this can be slower though.
371
  cfg_conditions: CFGConditions
372
+ # two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
373
+
374
+ null_conditions = conditions
375
+ conditions = conditions + null_conditions
376
+ tokenized = self.condition_provider.tokenize(conditions)
377
+ cfg_conditions = self.condition_provider(tokenized)
378
+
379
 
380
  if prompt is None:
381
  assert num_samples > 0
 
394
 
395
  gen_codes[..., :start_offset] = prompt
396
  # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
397
+ gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
398
+
 
399
  start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
400
+ # print('\n=', start_offset_sequence, '\n=') # 1
401
  assert start_offset_sequence is not None
402
 
403
  with self.streaming():
404
  unconditional_state = self.get_streaming_state()
405
  prev_offset = 0
406
  gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
407
+
408
+ # --
409
+ # print(mask.shape, mask.sum(), 'MSK LM')
410
+ # e.g. torch.Size([4, 39]), sum 140: mask is all ones, i.e. no special tokens in the pattern
411
+ # --
412
+
413
+
414
  for offset in range(start_offset_sequence, gen_sequence_len):
415
  # get current sequence (note that the streaming API is providing the caching over previous offsets)
416
+
417
  curr_sequence = gen_sequence[..., prev_offset:offset]
418
  curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
419
  if check:
 
448
  callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
449
  unconditional_state.clear()
450
 
451
+ out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
452
 
453
  out_start_offset = start_offset if remove_prompts else 0
454
  out_codes = out_codes[..., out_start_offset:max_gen_len]
455
 
456
  # ensure the returned codes are all valid
457
+
458
  # assert (out_codes >= 0).all() and (out_codes <= self.card).all()
459
+
460
  return out_codes
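For orientation, the simplified generate() above (i) duplicates the conditions so the conditional and null batches run together for classifier-free guidance, (ii) interleaves the codes with the delayed codebook pattern, filling not-yet-valid positions with a special token, and (iii) reverts the pattern once the loop finishes. A hedged round-trip sketch of steps (ii)/(iii), reusing the pattern API that appears in this diff; the import path follows this repo's layout, and the special-token id and shapes are assumptions:

```python
import torch
from audiocraft.codebooks_patterns import DelayedPatternProvider

B, K, T = 1, 4, 10
special_token = 2048                              # assumed id outside the codebook range
codes = torch.randint(0, 2048, (B, K, T))         # [B, K, T] time-aligned codes

pattern = DelayedPatternProvider(n_q=K).get_pattern(T)

# Interleave: codebook k is delayed by k steps; holes are filled with the special token.
gen_sequence, _, mask = pattern.build_pattern_sequence(codes, special_token)   # [B, K, S]

# ... LMModel.generate fills gen_sequence autoregressively between these two calls ...

# Undo the interleaving to recover time-aligned codes (invalid slots keep special_token).
out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence, special_token=special_token)
print(gen_sequence.shape, out_codes.shape)        # e.g. [1, 4, S] and [1, 4, 10]
```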
audiocraft/loaders.py CHANGED
@@ -101,7 +101,8 @@ def _delete_param(cfg: DictConfig, full_name: str):
101
  OmegaConf.set_struct(cfg, True)
102
 
103
 
104
- def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
 
105
  pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
106
  cfg = OmegaConf.create(pkg['xp.cfg'])
107
  cfg.device = str(device)
 
101
  OmegaConf.set_struct(cfg, True)
102
 
103
 
104
+ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
105
+ cache_dir: tp.Optional[str] = None):
106
  pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
107
  cfg = OmegaConf.create(pkg['xp.cfg'])
108
  cfg.device = str(device)
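A hedged usage sketch for the reformatted loader above; that `load_lm_model` is importable at module level and that the demo's Hugging Face id resolves here are assumptions, not something this diff guarantees.

```python
# Minimal sketch, assuming audiocraft/loaders.py exposes load_lm_model as shown above.
from audiocraft.loaders import load_lm_model

lm = load_lm_model('facebook/audiogen-medium', device='cpu')
print(type(lm).__name__)   # expected: LMModel
```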
demo.py CHANGED
@@ -1,15 +1,14 @@
1
  from audiocraft.audiogen import AudioGen #, audio_write
2
- import audiofile
3
- import numpy as np
4
 
5
  print('\n\n\n\n___________________')
6
 
7
- txt = 'car'
8
 
9
  sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
10
- sound_generator.set_generation_params(duration=1) # why is generating so long at 14 seconds
11
 
12
  x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
13
  x /= np.abs(x).max() + 1e-7
14
 
15
- audiofile.write('_audio1_.wav', x, 16000)
 
1
  from audiocraft.audiogen import AudioGen #, audio_write
2
+
 
3
 
4
  print('\n\n\n\n___________________')
5
 
6
+ txt = 'austrian music'
7
 
8
  sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
9
+ sound_generator.set_generation_params(duration=4.7) # why is generating so long at 14 seconds
10
 
11
  x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
12
  x /= np.abs(x).max() + 1e-7
13
 
14
+ audiofile.write('del_seane.wav', x, 16000)
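Note that the updated demo.py still calls np.abs and audiofile.write in its unchanged lines even though those imports were dropped above. A self-contained variant would keep them; the numpy/audiofile dependency is an assumption carried over from the previous version of the script.

```python
import audiofile                 # re-added: still needed by the write() call below
import numpy as np               # re-added: still needed for peak normalisation
from audiocraft.audiogen import AudioGen

txt = 'austrian music'

sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
sound_generator.set_generation_params(duration=4.7)

x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
x /= np.abs(x).max() + 1e-7      # peak-normalise to avoid clipping

audiofile.write('del_seane.wav', x, 16000)   # 16 kHz output, as in the original demo
```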