Dionyssos committed
Commit 531e776 · 1 Parent(s): 06aa0fc

del xtr funs

audiocraft/builders.py CHANGED
@@ -28,10 +28,8 @@ from .codebooks_patterns import (
 )
 from .conditioners import (
     BaseConditioner,
-    CLAPEmbeddingConditioner,
     ConditionFuser,
     ConditioningProvider,
-    LUTConditioner,
     T5Conditioner,
 )
 from .unet import DiffusionUnet
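
For context, T5Conditioner is now the only text conditioner left in the builder imports. Below is a minimal, hypothetical sketch (not part of this commit) of wiring a text-only provider under these imports; the constructor arguments (name, output_dim, finetune, device), the "description" attribute key, and the import path are assumptions taken from upstream audiocraft and from the file layout shown here.

    # Hypothetical sketch, not part of this commit.
    from audiocraft.conditioners import ConditioningProvider, T5Conditioner

    def build_text_provider(output_dim: int, device: str = "cpu") -> ConditioningProvider:
        # One conditioner per attribute; only the text path remains after this change.
        conditioners = {
            "description": T5Conditioner(name="t5-base", output_dim=output_dim,
                                         finetune=False, device=device),
        }
        return ConditioningProvider(conditioners, device=device)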
audiocraft/conditioners.py CHANGED
@@ -19,7 +19,7 @@ import soundfile
 import einops
 from num2words import num2words
 import spacy
-from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer  # type: ignore
+from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -317,39 +317,7 @@ class TextConditioner(BaseConditioner):
     ...
 
 
-class LUTConditioner(TextConditioner):
-    """Lookup table TextConditioner.
-
-    Args:
-        n_bins (int): Number of bins.
-        dim (int): Hidden dim of the model (text-encoder/LUT).
-        output_dim (int): Output dim of the conditioner.
-        tokenizer (str): Name of the tokenizer.
-        pad_idx (int, optional): Index for padding token. Defaults to 0.
-    """
-    def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
-        super().__init__(dim, output_dim)
-        self.embed = nn.Embedding(n_bins, dim)
-        self.tokenizer: Tokenizer
-        if tokenizer == 'whitespace':
-            self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
-        elif tokenizer == 'noop':
-            self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
-        else:
-            raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
-
-    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        device = self.embed.weight.device
-        tokens, mask = self.tokenizer(x)
-        tokens, mask = tokens.to(device), mask.to(device)
-        return tokens, mask
-
-    def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
-        tokens, mask = inputs
-        embeds = self.embed(tokens)
-        embeds = self.output_proj(embeds)
-        embeds = (embeds * mask.unsqueeze(-1))
-        return embeds, mask
+
 
 
 class T5Conditioner(TextConditioner):
@@ -448,357 +416,7 @@ class T5Conditioner(TextConditioner):
448
  return embeds, mask
449
 
450
 
451
- class WaveformConditioner(BaseConditioner):
452
- """Base class for all conditioners that take a waveform as input.
453
- Classes that inherit must implement `_get_wav_embedding` that outputs
454
- a continuous tensor, and `_downsampling_factor` that returns the down-sampling
455
- factor of the embedding model.
456
-
457
- Args:
458
- dim (int): The internal representation dimension.
459
- output_dim (int): Output dimension.
460
- device (tp.Union[torch.device, str]): Device.
461
- """
462
- def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
463
- super().__init__(dim, output_dim)
464
- self.device = device
465
- # if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample.
466
- self._use_masking = True
467
-
468
- def tokenize(self, x: WavCondition) -> WavCondition:
469
- wav, length, sample_rate, path, seek_time = x
470
- assert length is not None
471
- return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
472
-
473
- def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
474
- """Gets as input a WavCondition and returns a dense embedding."""
475
- raise NotImplementedError()
476
-
477
- def _downsampling_factor(self):
478
- """Returns the downsampling factor of the embedding model."""
479
- raise NotImplementedError()
480
-
481
- def forward(self, x: WavCondition) -> ConditionType:
482
- """Extract condition embedding and mask from a waveform and its metadata.
483
- Args:
484
- x (WavCondition): Waveform condition containing raw waveform and metadata.
485
- Returns:
486
- ConditionType: a dense vector representing the conditioning along with its mask
487
- """
488
- wav, lengths, *_ = x
489
- with torch.no_grad():
490
- embeds = self._get_wav_embedding(x)
491
- embeds = embeds.to(self.output_proj.weight)
492
- embeds = self.output_proj(embeds)
493
-
494
- if lengths is not None and self._use_masking:
495
- lengths = lengths / self._downsampling_factor()
496
- mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
497
- else:
498
- mask = torch.ones_like(embeds[..., 0])
499
- embeds = (embeds * mask.unsqueeze(-1))
500
- return embeds, mask
501
-
502
-
503
-
504
-
505
 
506
- class JointEmbeddingConditioner(BaseConditioner):
507
- """Joint embedding conditioning supporting both audio or text conditioning.
508
-
509
- Args:
510
- dim (int): Dimension.
511
- output_dim (int): Output dimension.
512
- device (str): Device.
513
- attribute (str): Attribute used by the conditioner.
514
- autocast_dtype (str): Autocast for the conditioner.
515
- quantize (bool): Whether to quantize the CLAP embedding.
516
- n_q (int): Number of residual quantizers (used if quantize is true).
517
- bins (int): Quantizers' codebooks size (used if quantize is true).
518
- kwargs: Additional parameters for residual vector quantizer.
519
- """
520
- def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
521
- autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
522
- n_q: int = 12, bins: int = 1024, **kwargs):
523
- super().__init__(dim=dim, output_dim=output_dim)
524
- self.device = device
525
- self.attribute = attribute
526
- if autocast_dtype is None or device == 'cpu':
527
- self.autocast = TorchAutocast(enabled=False)
528
- logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
529
- else:
530
- dtype = getattr(torch, autocast_dtype)
531
- assert isinstance(dtype, torch.dtype)
532
- logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
533
- self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
534
- # residual vector quantizer to discretize the conditioned embedding
535
- self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
536
- if quantize:
537
- self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
538
-
539
- def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
540
- """Get joint embedding in latent space from the inputs.
541
-
542
- Returns:
543
- tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
544
- and corresponding empty indexes.
545
- """
546
- raise NotImplementedError()
547
-
548
- def forward(self, x: JointEmbedCondition) -> ConditionType:
549
- with self.autocast:
550
- embed, empty_idx = self._get_embed(x)
551
- if self.quantizer is not None:
552
- embed = embed.view(-1, self.dim, 1)
553
- q_res = self.quantizer(embed, frame_rate=1)
554
- out_embed = q_res.x.view(-1, self.dim)
555
- else:
556
- out_embed = embed
557
- out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
558
- mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
559
- mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
560
- out_embed = (out_embed * mask.unsqueeze(-1))
561
- return out_embed, mask
562
-
563
- def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
564
- return x
565
-
566
-
567
- class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
568
- """Joint Embedding conditioner based on pre-trained CLAP model.
569
-
570
- This CLAP-based conditioner supports a caching mechanism
571
- over the computed embeddings for faster training.
572
-
573
- Args:
574
- dim (int): Dimension.
575
- output_dim (int): Output dimension.
576
- device (str): Device.
577
- attribute (str): Attribute used by the conditioner.
578
- quantize (bool): Whether to quantize the CLAP embedding.
579
- n_q (int): Number of residual quantizers (used if quantize is true).
580
- bins (int): Quantizers' codebooks size (used if quantize is true).
581
- checkpoint (str): Path to CLAP checkpoint.
582
- model_arch (str): CLAP model architecture.
583
- enable_fusion (bool): Enable fusion for CLAP model.
584
- sample_rate (int): Sample rate used by CLAP model.
585
- max_audio_length (float): Maximum audio length for CLAP model.
586
- audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
587
- normalize (bool): Whether to normalize the CLAP embedding.
588
- text_p (float): Probability of using text representation instead of audio at train time.
589
- batch_size (Optional[int]): Batch size for CLAP embedding computation.
590
- autocast_dtype (str): Autocast for the conditioner.
591
- cache_path (Optional[str]): Path for pre-computed embeddings caching.
592
- kwargs: Additional parameters for residual vector quantizer.
593
- """
594
- def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
595
- quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
596
- enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
597
- normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
598
- autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
599
- try:
600
- import laion_clap # type: ignore
601
- except ImportError:
602
- raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
603
- warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
604
- "Please retrain all models.")
605
- # checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
606
- clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
607
- clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
608
- load_clap_state_dict(clap_model, checkpoint)
609
- clap_model.eval()
610
- clap_model.to(device)
611
- super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
612
- autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
613
- **kwargs)
614
- self.checkpoint = checkpoint
615
- self.enable_fusion = enable_fusion
616
- self.model_arch = model_arch
617
- self.clap: laion_clap.CLAP_Module
618
- self.clap_tokenize: RobertaTokenizer
619
- self.clap_sample_rate = sample_rate
620
- self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
621
- self.clap_stride = int(self.clap_sample_rate * audio_stride)
622
- self.batch_size = batch_size or 1
623
- self.normalize = normalize
624
- self.text_p = text_p
625
- self.__dict__['clap_tokenize'] = clap_tokenize
626
- self.__dict__['clap'] = clap_model
627
- self.wav_cache, self.text_cache = None, None
628
- if cache_path is not None:
629
- self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
630
- compute_embed_fn=self._get_wav_embedding_for_cache,
631
- extract_embed_fn=self._extract_wav_embedding_chunk)
632
- self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
633
- compute_embed_fn=self._get_text_embedding_for_cache)
634
-
635
- def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
636
- # we use the default params from CLAP module here as well
637
- return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
638
-
639
- def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
640
- """Compute text embedding from CLAP model on a given a batch of text.
641
-
642
- Args:
643
- text (list[str]): List of text for the batch, with B items.
644
- Returns:
645
- torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
646
- """
647
- with torch.no_grad():
648
- embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
649
- return embed.view(embed.size(0), 1, embed.size(-1))
650
-
651
- def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
652
- x: JointEmbedCondition, idx: int) -> torch.Tensor:
653
- """Get text embedding function for the cache."""
654
- text = x.text[idx]
655
- text = text if text is not None else ""
656
- return self._compute_text_embedding([text])[0]
657
-
658
- def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
659
- """Preprocess wav to expected format by CLAP model.
660
-
661
- Args:
662
- wav (torch.Tensor): Audio wav, of shape [B, C, T].
663
- length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
664
- sample_rates (list[int]): Sample rates for each sample in the batch
665
- Returns:
666
- torch.Tensor: Audio wav of shape [B, T].
667
- """
668
- assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
669
- if sample_rates is not None:
670
- _wav = []
671
- for i, audio in enumerate(wav):
672
- sr = sample_rates[i]
673
- audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
674
- _wav.append(audio)
675
- wav = torch.stack(_wav, dim=0)
676
- wav = wav.mean(dim=1)
677
- return wav
678
-
679
- def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
680
- sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
681
- """Compute audio wave embedding from CLAP model.
682
-
683
- Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
684
- we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
685
- average the resulting embeddings.
686
-
687
- Args:
688
- wav (torch.Tensor): Audio wav, of shape [B, C, T].
689
- length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
690
- sample_rates (list[int]): Sample rates for each sample in the batch.
691
- reduce_mean (bool): Whether to get the average tensor.
692
- Returns:
693
- torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
694
- """
695
- with torch.no_grad():
696
- wav = self._preprocess_wav(wav, length, sample_rates)
697
- B, T = wav.shape
698
- if T >= self.clap_max_frames:
699
- wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
700
- else:
701
- wav = wav.view(-1, 1, T) # [B, F, T] with F=1
702
- wav = einops.rearrange(wav, 'b f t -> (b f) t')
703
- embed_list = []
704
- for i in range(0, wav.size(0), self.batch_size):
705
- _wav = wav[i:i+self.batch_size, ...]
706
- _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
707
- embed_list.append(_embed)
708
- embed = torch.cat(embed_list, dim=0)
709
- embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
710
- if reduce_mean:
711
- embed = embed.mean(dim=1, keepdim=True)
712
- return embed # [B, F, D] with F=1 if reduce_mean is True
713
-
714
- def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
715
- x: JointEmbedCondition, idx: int) -> torch.Tensor:
716
- """Compute audio wave embedding for the cache.
717
- The embedding is computed on a given audio read from file.
718
-
719
- Args:
720
- path (str or Path): Path to the full audio file.
721
- Returns:
722
- torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
723
- """
724
- wav, sr = soundfile.read(path) # [C, T]
725
- wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
726
- wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
727
- embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
728
- return embed.squeeze(0) # [F, D]
729
-
730
- def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
731
- """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
732
-
733
- Args:
734
- full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
735
- x (JointEmbedCondition): Joint embedding condition for the full batch.
736
- idx (int): Index considered for the given embedding to extract.
737
- Returns:
738
- torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
739
- """
740
- sample_rate = x.sample_rate[idx]
741
- seek_time = x.seek_time[idx]
742
- seek_time = 0. if seek_time is None else seek_time
743
- clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
744
- end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
745
- start_offset = int(seek_time * sample_rate // clap_stride)
746
- end_offset = int(end_seek_time * sample_rate // clap_stride)
747
- wav_embed = full_embed[start_offset:end_offset, ...]
748
- wav_embed = wav_embed.mean(dim=0, keepdim=True)
749
- return wav_embed.to(self.device) # [F, D]
750
-
751
- def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
752
- """Get CLAP embedding from a batch of text descriptions."""
753
- no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
754
- if self.text_cache is not None and no_nullified_cond:
755
- assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
756
- paths = [Path(p) for p in x.path if p is not None]
757
- embed = self.text_cache.get_embed_from_cache(paths, x)
758
- else:
759
- text = [xi if xi is not None else "" for xi in x.text]
760
- embed = self._compute_text_embedding(text)
761
- if self.normalize:
762
- embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
763
- return embed
764
-
765
- def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
766
- """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
767
- no_undefined_paths = all(p is not None for p in x.path)
768
- no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
769
- if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
770
- paths = [Path(p) for p in x.path if p is not None]
771
- embed = self.wav_cache.get_embed_from_cache(paths, x)
772
- else:
773
- embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
774
- if self.normalize:
775
- embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
776
- return embed
777
-
778
- def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
779
- # Trying to limit as much as possible sync points when the cache is warm.
780
- no_undefined_paths = all(p is not None for p in x.path)
781
- if self.wav_cache is not None and no_undefined_paths:
782
- assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
783
- paths = [Path(p) for p in x.path if p is not None]
784
- self.wav_cache.populate_embed_cache(paths, x)
785
- if self.text_cache is not None and no_undefined_paths:
786
- assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
787
- paths = [Path(p) for p in x.path if p is not None]
788
- self.text_cache.populate_embed_cache(paths, x)
789
- return x
790
-
791
- def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
792
- """Extract shared latent representation from either the wav or the text using CLAP."""
793
- # decide whether to use text embedding at train time or not
794
- use_text_embed = random.random() < self.text_p
795
- if self.training and not use_text_embed:
796
- embed = self._get_wav_embedding(x)
797
- empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
798
- else:
799
- embed = self._get_text_embedding(x)
800
- empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
801
- return embed, empty_idx
802
 
803
 
804
  def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
@@ -938,25 +556,19 @@ class ConditioningProvider(nn.Module):
         self.device = device
         self.conditioners = nn.ModuleDict(conditioners)
 
-    @property
-    def joint_embed_conditions(self):
-        return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
+    # @property
+    # def joint_embed_conditions(self):
+    #     return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
 
-    @property
-    def has_joint_embed_conditions(self):
-        return len(self.joint_embed_conditions) > 0
+    # @property
+    # def has_joint_embed_conditions(self):
+    #     return len(self.joint_embed_conditions) > 0
 
     @property
     def text_conditions(self):
         return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
 
-    @property
-    def wav_conditions(self):
-        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
 
-    @property
-    def has_wav_condition(self):
-        return len(self.wav_conditions) > 0
 
     def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
         """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
@@ -974,15 +586,15 @@ class ConditioningProvider(nn.Module):
 
         output = {}
         text = self._collate_text(inputs)
-        wavs = self._collate_wavs(inputs)
-        joint_embeds = self._collate_joint_embeds(inputs)
+        # wavs = self._collate_wavs(inputs)
+        # joint_embeds = self._collate_joint_embeds(inputs)
 
-        assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
-            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
-            f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
-        )
+        # assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
+        #     f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
+        #     f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
+        # )
 
-        for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
+        for attribute, batch in text.items(): #, joint_embeds.items()):
             output[attribute] = self.conditioners[attribute].tokenize(batch)
         return output
 
@@ -1031,102 +643,9 @@ class ConditioningProvider(nn.Module):
                 out[condition].append(text[condition])
         return out
 
-    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
-        """Generate a dict where the keys are attributes by which we fetch similar wavs,
-        and the values are Tensors of wavs according to said attributes.
-
-        *Note*: by the time the samples reach this function, each sample should have some waveform
-        inside the "wav" attribute. It should be either:
-        1. A real waveform
-        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
-        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
-
-        Args:
-            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
-        Returns:
-            dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
-        """
-        wavs = defaultdict(list)
-        lengths = defaultdict(list)
-        sample_rates = defaultdict(list)
-        paths = defaultdict(list)
-        seek_times = defaultdict(list)
-        out: tp.Dict[str, WavCondition] = {}
-
-        for sample in samples:
-            for attribute in self.wav_conditions:
-                wav, length, sample_rate, path, seek_time = sample.wav[attribute]
-                assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
-                assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
-                # mono-channel conditioning
-                wav = wav.mean(1, keepdim=True)  # [1, 1, T]
-                wavs[attribute].append(wav.flatten())  # [T]
-                lengths[attribute].append(length)
-                sample_rates[attribute].extend(sample_rate)
-                paths[attribute].extend(path)
-                seek_times[attribute].extend(seek_time)
-
-        # stack all wavs to a single tensor
-        for attribute in self.wav_conditions:
-            stacked_wav, _ = collate(wavs[attribute], dim=0)
-            out[attribute] = WavCondition(
-                stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
-                paths[attribute], seek_times[attribute])
-
-        return out
-
-    def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
-        """Generate a dict where the keys are attributes by which we compute joint embeddings,
-        and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
-
-        Args:
-            samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
-        Returns:
-            A dictionary mapping an attribute name to joint embeddings.
-        """
-        texts = defaultdict(list)
-        wavs = defaultdict(list)
-        lengths = defaultdict(list)
-        sample_rates = defaultdict(list)
-        paths = defaultdict(list)
-        seek_times = defaultdict(list)
-        channels: int = 0
-
-        out = {}
-        for sample in samples:
-            for attribute in self.joint_embed_conditions:
-                wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
-                assert wav.dim() == 3
-                if channels == 0:
-                    channels = wav.size(1)
-                else:
-                    assert channels == wav.size(1), "not all audio has same number of channels in batch"
-                assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
-                wav = einops.rearrange(wav, "b c t -> (b c t)")  # [1, C, T] => [C * T]
-                wavs[attribute].append(wav)
-                texts[attribute].extend(text)
-                lengths[attribute].append(length)
-                sample_rates[attribute].extend(sample_rate)
-                paths[attribute].extend(path)
-                seek_times[attribute].extend(seek_time)
-
-        for attribute in self.joint_embed_conditions:
-            stacked_texts = texts[attribute]
-            stacked_paths = paths[attribute]
-            stacked_seek_times = seek_times[attribute]
-            stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
-            stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
-            stacked_sample_rates = sample_rates[attribute]
-            stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
-            assert stacked_lengths.size(0) == stacked_wavs.size(0)
-            assert len(stacked_sample_rates) == stacked_wavs.size(0)
-            assert len(stacked_texts) == stacked_wavs.size(0)
-            out[attribute] = JointEmbedCondition(
-                text=stacked_texts, wav=stacked_wavs,
-                length=stacked_lengths, sample_rate=stacked_sample_rates,
-                path=stacked_paths, seek_time=stacked_seek_times)
-
-        return out
+
 
 
 class ConditionFuser(StreamingModule):
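
With the wav and joint-embedding collation disabled above, ConditioningProvider.tokenize only routes text attributes. A rough, hypothetical usage sketch of the remaining text-only path follows; ConditioningAttributes(text={...}) and the "description" key mirror upstream audiocraft and are assumptions about this fork, and `provider` stands for an instance built as in the builders.py sketch earlier.

    # Hypothetical sketch, not part of this commit.
    from audiocraft.conditioners import ConditioningAttributes

    attrs = [
        ConditioningAttributes(text={"description": "a calm male voice"}),
        ConditioningAttributes(text={"description": None}),  # nullified (dropped-out) condition
    ]
    tokenized = provider.tokenize(attrs)     # {'description': (tokens, mask)} from T5Conditioner.tokenize
    condition_tensors = provider(tokenized)  # {'description': (embeds, mask)} ready for the ConditionFuser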
 
audiocraft/lm.py CHANGED
@@ -322,39 +322,7 @@ class TextConditioner(BaseConditioner):
322
  ...
323
 
324
 
325
- class LUTConditioner(TextConditioner):
326
- """Lookup table TextConditioner.
327
 
328
- Args:
329
- n_bins (int): Number of bins.
330
- dim (int): Hidden dim of the model (text-encoder/LUT).
331
- output_dim (int): Output dim of the conditioner.
332
- tokenizer (str): Name of the tokenizer.
333
- pad_idx (int, optional): Index for padding token. Defaults to 0.
334
- """
335
- def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
336
- super().__init__(dim, output_dim)
337
- self.embed = nn.Embedding(n_bins, dim)
338
- self.tokenizer: Tokenizer
339
- if tokenizer == 'whitespace':
340
- self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
341
- elif tokenizer == 'noop':
342
- self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
343
- else:
344
- raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
345
-
346
- def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
347
- device = self.embed.weight.device
348
- tokens, mask = self.tokenizer(x)
349
- tokens, mask = tokens.to(device), mask.to(device)
350
- return tokens, mask
351
-
352
- def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
353
- tokens, mask = inputs
354
- embeds = self.embed(tokens)
355
- embeds = self.output_proj(embeds)
356
- embeds = (embeds * mask.unsqueeze(-1))
357
- return embeds, mask
358
 
359
 
360
  class T5Conditioner(TextConditioner):
@@ -453,56 +421,7 @@ class T5Conditioner(TextConditioner):
453
  return embeds, mask
454
 
455
 
456
- class WaveformConditioner(BaseConditioner):
457
- """Base class for all conditioners that take a waveform as input.
458
- Classes that inherit must implement `_get_wav_embedding` that outputs
459
- a continuous tensor, and `_downsampling_factor` that returns the down-sampling
460
- factor of the embedding model.
461
-
462
- Args:
463
- dim (int): The internal representation dimension.
464
- output_dim (int): Output dimension.
465
- device (tp.Union[torch.device, str]): Device.
466
- """
467
- def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
468
- super().__init__(dim, output_dim)
469
- self.device = device
470
- # if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample.
471
- self._use_masking = True
472
-
473
- def tokenize(self, x: WavCondition) -> WavCondition:
474
- wav, length, sample_rate, path, seek_time = x
475
- assert length is not None
476
- return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
477
-
478
- def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
479
- """Gets as input a WavCondition and returns a dense embedding."""
480
- raise NotImplementedError()
481
 
482
- def _downsampling_factor(self):
483
- """Returns the downsampling factor of the embedding model."""
484
- raise NotImplementedError()
485
-
486
- def forward(self, x: WavCondition) -> ConditionType:
487
- """Extract condition embedding and mask from a waveform and its metadata.
488
- Args:
489
- x (WavCondition): Waveform condition containing raw waveform and metadata.
490
- Returns:
491
- ConditionType: a dense vector representing the conditioning along with its mask
492
- """
493
- wav, lengths, *_ = x
494
- with torch.no_grad():
495
- embeds = self._get_wav_embedding(x)
496
- embeds = embeds.to(self.output_proj.weight)
497
- embeds = self.output_proj(embeds)
498
-
499
- if lengths is not None and self._use_masking:
500
- lengths = lengths / self._downsampling_factor()
501
- mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
502
- else:
503
- mask = torch.ones_like(embeds[..., 0])
504
- embeds = (embeds * mask.unsqueeze(-1))
505
- return embeds, mask
506
 
507
 
508
 
@@ -570,366 +489,13 @@ class JointEmbeddingConditioner(BaseConditioner):
570
  return x
571
 
572
 
573
- class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
574
- """Joint Embedding conditioner based on pre-trained CLAP model.
575
-
576
- This CLAP-based conditioner supports a caching mechanism
577
- over the computed embeddings for faster training.
578
 
579
- Args:
580
- dim (int): Dimension.
581
- output_dim (int): Output dimension.
582
- device (str): Device.
583
- attribute (str): Attribute used by the conditioner.
584
- quantize (bool): Whether to quantize the CLAP embedding.
585
- n_q (int): Number of residual quantizers (used if quantize is true).
586
- bins (int): Quantizers' codebooks size (used if quantize is true).
587
- checkpoint (str): Path to CLAP checkpoint.
588
- model_arch (str): CLAP model architecture.
589
- enable_fusion (bool): Enable fusion for CLAP model.
590
- sample_rate (int): Sample rate used by CLAP model.
591
- max_audio_length (float): Maximum audio length for CLAP model.
592
- audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
593
- normalize (bool): Whether to normalize the CLAP embedding.
594
- text_p (float): Probability of using text representation instead of audio at train time.
595
- batch_size (Optional[int]): Batch size for CLAP embedding computation.
596
- autocast_dtype (str): Autocast for the conditioner.
597
- cache_path (Optional[str]): Path for pre-computed embeddings caching.
598
- kwargs: Additional parameters for residual vector quantizer.
599
- """
600
- def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
601
- quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
602
- enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
603
- normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
604
- autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
605
- try:
606
- import laion_clap # type: ignore
607
- except ImportError:
608
- raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
609
- warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
610
- "Please retrain all models.")
611
- checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
612
- clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
613
- clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
614
- load_clap_state_dict(clap_model, checkpoint)
615
- clap_model.eval()
616
- clap_model.to(device)
617
- super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
618
- autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
619
- **kwargs)
620
- self.checkpoint = checkpoint
621
- self.enable_fusion = enable_fusion
622
- self.model_arch = model_arch
623
- self.clap: laion_clap.CLAP_Module
624
- self.clap_tokenize: RobertaTokenizer
625
- self.clap_sample_rate = sample_rate
626
- self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
627
- self.clap_stride = int(self.clap_sample_rate * audio_stride)
628
- self.batch_size = batch_size or 1
629
- self.normalize = normalize
630
- self.text_p = text_p
631
- self.__dict__['clap_tokenize'] = clap_tokenize
632
- self.__dict__['clap'] = clap_model
633
- self.wav_cache, self.text_cache = None, None
634
- if cache_path is not None:
635
- self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
636
- compute_embed_fn=self._get_wav_embedding_for_cache,
637
- extract_embed_fn=self._extract_wav_embedding_chunk)
638
- self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
639
- compute_embed_fn=self._get_text_embedding_for_cache)
640
-
641
- def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
642
- # we use the default params from CLAP module here as well
643
- return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
644
-
645
- def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
646
- """Compute text embedding from CLAP model on a given a batch of text.
647
 
648
- Args:
649
- text (list[str]): List of text for the batch, with B items.
650
- Returns:
651
- torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
652
- """
653
- with torch.no_grad():
654
- embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
655
- return embed.view(embed.size(0), 1, embed.size(-1))
656
-
657
- def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
658
- x: JointEmbedCondition, idx: int) -> torch.Tensor:
659
- """Get text embedding function for the cache."""
660
- text = x.text[idx]
661
- text = text if text is not None else ""
662
- return self._compute_text_embedding([text])[0]
663
 
664
- def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
665
- """Preprocess wav to expected format by CLAP model.
666
-
667
- Args:
668
- wav (torch.Tensor): Audio wav, of shape [B, C, T].
669
- length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
670
- sample_rates (list[int]): Sample rates for each sample in the batch
671
- Returns:
672
- torch.Tensor: Audio wav of shape [B, T].
673
- """
674
- assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
675
- if sample_rates is not None:
676
- _wav = []
677
- for i, audio in enumerate(wav):
678
- sr = sample_rates[i]
679
- audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
680
- _wav.append(audio)
681
- wav = torch.stack(_wav, dim=0)
682
- wav = wav.mean(dim=1)
683
- return wav
684
-
685
- def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
686
- sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
687
- """Compute audio wave embedding from CLAP model.
688
-
689
- Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
690
- we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
691
- average the resulting embeddings.
692
-
693
- Args:
694
- wav (torch.Tensor): Audio wav, of shape [B, C, T].
695
- length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
696
- sample_rates (list[int]): Sample rates for each sample in the batch.
697
- reduce_mean (bool): Whether to get the average tensor.
698
- Returns:
699
- torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
700
- """
701
- with torch.no_grad():
702
- wav = self._preprocess_wav(wav, length, sample_rates)
703
- B, T = wav.shape
704
- if T >= self.clap_max_frames:
705
- wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
706
- else:
707
- wav = wav.view(-1, 1, T) # [B, F, T] with F=1
708
- wav = einops.rearrange(wav, 'b f t -> (b f) t')
709
- embed_list = []
710
- for i in range(0, wav.size(0), self.batch_size):
711
- _wav = wav[i:i+self.batch_size, ...]
712
- _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
713
- embed_list.append(_embed)
714
- embed = torch.cat(embed_list, dim=0)
715
- embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
716
- if reduce_mean:
717
- embed = embed.mean(dim=1, keepdim=True)
718
- return embed # [B, F, D] with F=1 if reduce_mean is True
719
-
720
- def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
721
- x: JointEmbedCondition, idx: int) -> torch.Tensor:
722
- """Compute audio wave embedding for the cache.
723
- The embedding is computed on a given audio read from file.
724
 
725
- Args:
726
- path (str or Path): Path to the full audio file.
727
- Returns:
728
- torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
729
- """
730
- wav, sr = soundfile.read(path) # [C, T]
731
- wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
732
- wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
733
- embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
734
- return embed.squeeze(0) # [F, D]
735
-
736
- def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
737
- """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
738
-
739
- Args:
740
- full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
741
- x (JointEmbedCondition): Joint embedding condition for the full batch.
742
- idx (int): Index considered for the given embedding to extract.
743
- Returns:
744
- torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
745
- """
746
- sample_rate = x.sample_rate[idx]
747
- seek_time = x.seek_time[idx]
748
- seek_time = 0. if seek_time is None else seek_time
749
- clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
750
- end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
751
- start_offset = int(seek_time * sample_rate // clap_stride)
752
- end_offset = int(end_seek_time * sample_rate // clap_stride)
753
- wav_embed = full_embed[start_offset:end_offset, ...]
754
- wav_embed = wav_embed.mean(dim=0, keepdim=True)
755
- return wav_embed.to(self.device) # [F, D]
756
-
757
- def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
758
- """Get CLAP embedding from a batch of text descriptions."""
759
- no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
760
- if self.text_cache is not None and no_nullified_cond:
761
- assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
762
- paths = [Path(p) for p in x.path if p is not None]
763
- embed = self.text_cache.get_embed_from_cache(paths, x)
764
- else:
765
- text = [xi if xi is not None else "" for xi in x.text]
766
- embed = self._compute_text_embedding(text)
767
- if self.normalize:
768
- embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
769
- return embed
770
-
771
- def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
772
- """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
773
- no_undefined_paths = all(p is not None for p in x.path)
774
- no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
775
- if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
776
- paths = [Path(p) for p in x.path if p is not None]
777
- embed = self.wav_cache.get_embed_from_cache(paths, x)
778
- else:
779
- embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
780
- if self.normalize:
781
- embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
782
- return embed
783
-
784
- def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
785
- # Trying to limit as much as possible sync points when the cache is warm.
786
- no_undefined_paths = all(p is not None for p in x.path)
787
- if self.wav_cache is not None and no_undefined_paths:
788
- assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
789
- paths = [Path(p) for p in x.path if p is not None]
790
- self.wav_cache.populate_embed_cache(paths, x)
791
- if self.text_cache is not None and no_undefined_paths:
792
- assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
793
- paths = [Path(p) for p in x.path if p is not None]
794
- self.text_cache.populate_embed_cache(paths, x)
795
- return x
796
-
797
- def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
798
- """Extract shared latent representation from either the wav or the text using CLAP."""
799
- # decide whether to use text embedding at train time or not
800
- use_text_embed = random.random() < self.text_p
801
- if self.training and not use_text_embed:
802
- embed = self._get_wav_embedding(x)
803
- empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
804
- else:
805
- embed = self._get_text_embedding(x)
806
- empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
807
- return embed, empty_idx
808
-
809
-
810
- def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
811
- """Utility function for nullifying an attribute inside an ConditioningAttributes object.
812
- If the condition is of type "wav", then nullify it using `nullify_condition` function.
813
- If the condition is of any other type, set its value to None.
814
- Works in-place.
815
- """
816
- if condition_type not in ['text', 'wav', 'joint_embed']:
817
- raise ValueError(
818
- "dropout_condition got an unexpected condition type!"
819
- f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
820
- )
821
-
822
- if condition not in getattr(sample, condition_type):
823
- raise ValueError(
824
- "dropout_condition received an unexpected condition!"
825
- f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
826
- f" but got '{condition}' of type '{condition_type}'!"
827
- )
828
-
829
- if condition_type == 'wav':
830
- wav_cond = sample.wav[condition]
831
- sample.wav[condition] = nullify_wav(wav_cond)
832
- elif condition_type == 'joint_embed':
833
- embed = sample.joint_embed[condition]
834
- sample.joint_embed[condition] = nullify_joint_embed(embed)
835
- else:
836
- sample.text[condition] = None
837
-
838
- return sample
839
-
840
-
841
- class DropoutModule(nn.Module):
842
- """Base module for all dropout modules."""
843
- def __init__(self, seed: int = 1234):
844
- super().__init__()
845
- self.rng = torch.Generator()
846
- self.rng.manual_seed(seed)
847
-
848
-
849
- class AttributeDropout(DropoutModule):
850
- """Dropout with a given probability per attribute.
851
- This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
852
- to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
853
- This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
854
- must also be dropped.
855
-
856
- Args:
857
- p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
858
- ...
859
- "genre": 0.1,
860
- "artist": 0.5,
861
- "wav": 0.25,
862
- ...
863
- active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
864
- seed (int, optional): Random seed.
865
- """
866
- def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
867
- super().__init__(seed=seed)
868
- self.active_on_eval = active_on_eval
869
- # construct dict that return the values from p otherwise 0
870
- self.p = {}
871
- for condition_type, probs in p.items():
872
- self.p[condition_type] = defaultdict(lambda: 0, probs)
873
-
874
- def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
875
- """
876
- Args:
877
- samples (list[ConditioningAttributes]): List of conditions.
878
- Returns:
879
- list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
880
- """
881
- if not self.training and not self.active_on_eval:
882
- return samples
883
-
884
- samples = deepcopy(samples)
885
- for condition_type, ps in self.p.items(): # for condition types [text, wav]
886
- for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
887
- if torch.rand(1, generator=self.rng).item() < p:
888
- for sample in samples:
889
- dropout_condition(sample, condition_type, condition)
890
- return samples
891
-
892
- def __repr__(self):
893
- return f"AttributeDropout({dict(self.p)})"
894
-
895
-
896
- class ClassifierFreeGuidanceDropout(DropoutModule):
897
- """Classifier Free Guidance dropout.
898
- All attributes are dropped with the same probability.
899
-
900
- Args:
901
- p (float): Probability to apply condition dropout during training.
902
- seed (int): Random seed.
903
- """
904
- def __init__(self, p: float, seed: int = 1234):
905
- super().__init__(seed=seed)
906
- self.p = p
907
-
908
- def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
909
- """
910
- Args:
911
- samples (list[ConditioningAttributes]): List of conditions.
912
- Returns:
913
- list[ConditioningAttributes]: List of conditions after all attributes were set to None.
914
- """
915
- if not self.training:
916
- return samples
917
 
918
- # decide on which attributes to drop in a batched fashion
919
- drop = torch.rand(1, generator=self.rng).item() < self.p
920
- if not drop:
921
- return samples
922
 
923
- # nullify conditions of all attributes
924
- samples = deepcopy(samples)
925
- for condition_type in ["wav", "text"]:
926
- for sample in samples:
927
- for condition in sample.attributes[condition_type]:
928
- dropout_condition(sample, condition_type, condition)
929
- return samples
930
 
931
- def __repr__(self):
932
- return f"ClassifierFreeGuidanceDropout(p={self.p})"
933
 
934
 
935
  class ConditioningProvider(nn.Module):
@@ -1355,8 +921,8 @@ class LMModel(StreamingModule):
                  **kwargs):
         super().__init__()
         self.cfg_coef = cfg_coef
-        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
-        self.att_dropout = AttributeDropout(p=attribute_dropout)
+
+
         self.condition_provider = condition_provider
         self.fuser = fuser
         self.card = card
@@ -1447,10 +1013,7 @@ class LMModel(StreamingModule):
         input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
         if condition_tensors is None:
             assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
-            # apply dropout modules
-            conditions = self.cfg_dropout(conditions)
-            conditions = self.att_dropout(conditions)
-            tokenized = self.condition_provider.tokenize(conditions)
+
             # encode conditions and fuse, both have a streaming cache to not recompute when generating.
             condition_tensors = self.condition_provider(tokenized)
         else:
@@ -1661,7 +1224,7 @@ class LMModel(StreamingModule):
         cfg_conditions: CFGConditions
         two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
         if conditions:
-            null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+            null_conditions = conditions
             if two_step_cfg:
                 cfg_conditions = (
                     self.condition_provider(self.condition_provider.tokenize(conditions)),
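
For reference, the ClassifierFreeGuidanceDropout(p=1.0)(conditions) call removed in the last hunk returned a deep copy of the conditions with every attribute nullified (the freshly constructed module starts in training mode and p=1.0 always triggers the drop, per the deleted forward above), whereas null_conditions = conditions now reuses the conditioned attributes for the unconditional branch. A hedged, hypothetical sketch of an equivalent text-only nullification, expressed with the dropout_condition helper that remains in conditioners.py, in case the old behaviour is wanted:

    # Hypothetical sketch, not part of this commit.
    import typing as tp
    from copy import deepcopy

    from audiocraft.conditioners import ConditioningAttributes, dropout_condition

    def null_text_conditions(conditions: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        # Same effect as ClassifierFreeGuidanceDropout(p=1.0) restricted to text attributes:
        # copy the batch and set every text condition to None.
        null = deepcopy(conditions)
        for sample in null:
            for condition in list(sample.text.keys()):
                dropout_condition(sample, "text", condition)
        return null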
 