del xtr funs
- audiocraft/builders.py +0 -2
- audiocraft/conditioners.py +15 -496
- audiocraft/lm.py +4 -441
audiocraft/builders.py
CHANGED
@@ -28,10 +28,8 @@ from .codebooks_patterns import (
 )
 from .conditioners import (
     BaseConditioner,
-    CLAPEmbeddingConditioner,
     ConditionFuser,
     ConditioningProvider,
-    LUTConditioner,
     T5Conditioner,
 )
 from .unet import DiffusionUnet
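Note: after this change only the text-conditioning classes remain importable here. A minimal sketch of how the surviving pieces could be wired together is shown below; the T5Conditioner keyword arguments and the ConditioningProvider argument order are assumptions for illustration only and are not taken from this diff.

# Hypothetical sketch (assumed signatures): text-only conditioning with the
# classes that builders.py still imports after this commit.
conditioners = {
    # 'description' is an assumed attribute name; the T5Conditioner kwargs are assumed too.
    'description': T5Conditioner(name='t5-base', output_dim=1536, finetune=False, device='cuda'),
}
provider = ConditioningProvider(conditioners, device='cuda')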
audiocraft/conditioners.py
CHANGED
@@ -19,7 +19,7 @@ import soundfile
 import einops
 from num2words import num2words
 import spacy
-from transformers import
+from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -317,39 +317,7 @@ class TextConditioner(BaseConditioner):
     ...


-class LUTConditioner(TextConditioner):
-    """Lookup table TextConditioner.
-
-    Args:
-        n_bins (int): Number of bins.
-        dim (int): Hidden dim of the model (text-encoder/LUT).
-        output_dim (int): Output dim of the conditioner.
-        tokenizer (str): Name of the tokenizer.
-        pad_idx (int, optional): Index for padding token. Defaults to 0.
-    """
-    def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
-        super().__init__(dim, output_dim)
-        self.embed = nn.Embedding(n_bins, dim)
-        self.tokenizer: Tokenizer
-        if tokenizer == 'whitespace':
-            self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
-        elif tokenizer == 'noop':
-            self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
-        else:
-            raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
-
-    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        device = self.embed.weight.device
-        tokens, mask = self.tokenizer(x)
-        tokens, mask = tokens.to(device), mask.to(device)
-        return tokens, mask
-
-    def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
-        tokens, mask = inputs
-        embeds = self.embed(tokens)
-        embeds = self.output_proj(embeds)
-        embeds = (embeds * mask.unsqueeze(-1))
-        return embeds, mask


 class T5Conditioner(TextConditioner):
@@ -448,357 +416,7 @@ class T5Conditioner(TextConditioner):
         return embeds, mask


-class WaveformConditioner(BaseConditioner):
-    """Base class for all conditioners that take a waveform as input.
-    Classes that inherit must implement `_get_wav_embedding` that outputs
-    a continuous tensor, and `_downsampling_factor` that returns the down-sampling
-    factor of the embedding model.
-
-    Args:
-        dim (int): The internal representation dimension.
-        output_dim (int): Output dimension.
-        device (tp.Union[torch.device, str]): Device.
-    """
-    def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
-        super().__init__(dim, output_dim)
-        self.device = device
-        # if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample.
-        self._use_masking = True
-
-    def tokenize(self, x: WavCondition) -> WavCondition:
-        wav, length, sample_rate, path, seek_time = x
-        assert length is not None
-        return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
-
-    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
-        """Gets as input a WavCondition and returns a dense embedding."""
-        raise NotImplementedError()
-
-    def _downsampling_factor(self):
-        """Returns the downsampling factor of the embedding model."""
-        raise NotImplementedError()
-
-    def forward(self, x: WavCondition) -> ConditionType:
-        """Extract condition embedding and mask from a waveform and its metadata.
-        Args:
-            x (WavCondition): Waveform condition containing raw waveform and metadata.
-        Returns:
-            ConditionType: a dense vector representing the conditioning along with its mask
-        """
-        wav, lengths, *_ = x
-        with torch.no_grad():
-            embeds = self._get_wav_embedding(x)
-        embeds = embeds.to(self.output_proj.weight)
-        embeds = self.output_proj(embeds)
-
-        if lengths is not None and self._use_masking:
-            lengths = lengths / self._downsampling_factor()
-            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
-        else:
-            mask = torch.ones_like(embeds[..., 0])
-        embeds = (embeds * mask.unsqueeze(-1))
-        return embeds, mask
-
-
-
-

-class JointEmbeddingConditioner(BaseConditioner):
-    """Joint embedding conditioning supporting both audio or text conditioning.
-
-    Args:
-        dim (int): Dimension.
-        output_dim (int): Output dimension.
-        device (str): Device.
-        attribute (str): Attribute used by the conditioner.
-        autocast_dtype (str): Autocast for the conditioner.
-        quantize (bool): Whether to quantize the CLAP embedding.
-        n_q (int): Number of residual quantizers (used if quantize is true).
-        bins (int): Quantizers' codebooks size (used if quantize is true).
-        kwargs: Additional parameters for residual vector quantizer.
-    """
-    def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
-                 autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
-                 n_q: int = 12, bins: int = 1024, **kwargs):
-        super().__init__(dim=dim, output_dim=output_dim)
-        self.device = device
-        self.attribute = attribute
-        if autocast_dtype is None or device == 'cpu':
-            self.autocast = TorchAutocast(enabled=False)
-            logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
-        else:
-            dtype = getattr(torch, autocast_dtype)
-            assert isinstance(dtype, torch.dtype)
-            logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
-            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
-        # residual vector quantizer to discretize the conditioned embedding
-        self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
-        if quantize:
-            self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
-
-    def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        """Get joint embedding in latent space from the inputs.
-
-        Returns:
-            tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
-                and corresponding empty indexes.
-        """
-        raise NotImplementedError()
-
-    def forward(self, x: JointEmbedCondition) -> ConditionType:
-        with self.autocast:
-            embed, empty_idx = self._get_embed(x)
-            if self.quantizer is not None:
-                embed = embed.view(-1, self.dim, 1)
-                q_res = self.quantizer(embed, frame_rate=1)
-                out_embed = q_res.x.view(-1, self.dim)
-            else:
-                out_embed = embed
-            out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
-            mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
-            mask[empty_idx, :] = 0  # zero-out index where the input is non-existant
-            out_embed = (out_embed * mask.unsqueeze(-1))
-            return out_embed, mask
-
-    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
-        return x
-
-
-class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
-    """Joint Embedding conditioner based on pre-trained CLAP model.
-
-    This CLAP-based conditioner supports a caching mechanism
-    over the computed embeddings for faster training.
-
-    Args:
-        dim (int): Dimension.
-        output_dim (int): Output dimension.
-        device (str): Device.
-        attribute (str): Attribute used by the conditioner.
-        quantize (bool): Whether to quantize the CLAP embedding.
-        n_q (int): Number of residual quantizers (used if quantize is true).
-        bins (int): Quantizers' codebooks size (used if quantize is true).
-        checkpoint (str): Path to CLAP checkpoint.
-        model_arch (str): CLAP model architecture.
-        enable_fusion (bool): Enable fusion for CLAP model.
-        sample_rate (int): Sample rate used by CLAP model.
-        max_audio_length (float): Maximum audio length for CLAP model.
-        audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
-        normalize (bool): Whether to normalize the CLAP embedding.
-        text_p (float): Probability of using text representation instead of audio at train time.
-        batch_size (Optional[int]): Batch size for CLAP embedding computation.
-        autocast_dtype (str): Autocast for the conditioner.
-        cache_path (Optional[str]): Path for pre-computed embeddings caching.
-        kwargs: Additional parameters for residual vector quantizer.
-    """
-    def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
-                 quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
-                 enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
-                 normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
-                 autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
-        try:
-            import laion_clap  # type: ignore
-        except ImportError:
-            raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
-        warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
-                      "Please retrain all models.")
-        # checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
-        clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
-        clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
-        load_clap_state_dict(clap_model, checkpoint)
-        clap_model.eval()
-        clap_model.to(device)
-        super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
-                         autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
-                         **kwargs)
-        self.checkpoint = checkpoint
-        self.enable_fusion = enable_fusion
-        self.model_arch = model_arch
-        self.clap: laion_clap.CLAP_Module
-        self.clap_tokenize: RobertaTokenizer
-        self.clap_sample_rate = sample_rate
-        self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
-        self.clap_stride = int(self.clap_sample_rate * audio_stride)
-        self.batch_size = batch_size or 1
-        self.normalize = normalize
-        self.text_p = text_p
-        self.__dict__['clap_tokenize'] = clap_tokenize
-        self.__dict__['clap'] = clap_model
-        self.wav_cache, self.text_cache = None, None
-        if cache_path is not None:
-            self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
-                                            compute_embed_fn=self._get_wav_embedding_for_cache,
-                                            extract_embed_fn=self._extract_wav_embedding_chunk)
-            self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
-                                             compute_embed_fn=self._get_text_embedding_for_cache)
-
-    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
-        # we use the default params from CLAP module here as well
-        return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
-
-    def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
-        """Compute text embedding from CLAP model on a given a batch of text.
-
-        Args:
-            text (list[str]): List of text for the batch, with B items.
-        Returns:
-            torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
-        """
-        with torch.no_grad():
-            embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
-            return embed.view(embed.size(0), 1, embed.size(-1))
-
-    def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
-                                      x: JointEmbedCondition, idx: int) -> torch.Tensor:
-        """Get text embedding function for the cache."""
-        text = x.text[idx]
-        text = text if text is not None else ""
-        return self._compute_text_embedding([text])[0]
-
-    def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
-        """Preprocess wav to expected format by CLAP model.
-
-        Args:
-            wav (torch.Tensor): Audio wav, of shape [B, C, T].
-            length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
-            sample_rates (list[int]): Sample rates for each sample in the batch
-        Returns:
-            torch.Tensor: Audio wav of shape [B, T].
-        """
-        assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
-        if sample_rates is not None:
-            _wav = []
-            for i, audio in enumerate(wav):
-                sr = sample_rates[i]
-                audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
-                _wav.append(audio)
-            wav = torch.stack(_wav, dim=0)
-        wav = wav.mean(dim=1)
-        return wav
-
-    def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
-                               sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
-        """Compute audio wave embedding from CLAP model.
-
-        Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
-        we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
-        average the resulting embeddings.
-
-        Args:
-            wav (torch.Tensor): Audio wav, of shape [B, C, T].
-            length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
-            sample_rates (list[int]): Sample rates for each sample in the batch.
-            reduce_mean (bool): Whether to get the average tensor.
-        Returns:
-            torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
-        """
-        with torch.no_grad():
-            wav = self._preprocess_wav(wav, length, sample_rates)
-            B, T = wav.shape
-            if T >= self.clap_max_frames:
-                wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride)  # [B, F, T]
-            else:
-                wav = wav.view(-1, 1, T)  # [B, F, T] with F=1
-            wav = einops.rearrange(wav, 'b f t -> (b f) t')
-            embed_list = []
-            for i in range(0, wav.size(0), self.batch_size):
-                _wav = wav[i:i+self.batch_size, ...]
-                _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
-                embed_list.append(_embed)
-            embed = torch.cat(embed_list, dim=0)
-            embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
-            if reduce_mean:
-                embed = embed.mean(dim=1, keepdim=True)
-            return embed  # [B, F, D] with F=1 if reduce_mean is True
-
-    def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
-                                     x: JointEmbedCondition, idx: int) -> torch.Tensor:
-        """Compute audio wave embedding for the cache.
-        The embedding is computed on a given audio read from file.
-
-        Args:
-            path (str or Path): Path to the full audio file.
-        Returns:
-            torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
-        """
-        wav, sr = soundfile.read(path)  # [C, T]
-        wav = wav.unsqueeze(0).to(self.device)  # [1, C, T]
-        wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
-        embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False)  # [B, F, D]
-        return embed.squeeze(0)  # [F, D]
-
-    def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
-        """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
-
-        Args:
-            full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
-            x (JointEmbedCondition): Joint embedding condition for the full batch.
-            idx (int): Index considered for the given embedding to extract.
-        Returns:
-            torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
-        """
-        sample_rate = x.sample_rate[idx]
-        seek_time = x.seek_time[idx]
-        seek_time = 0. if seek_time is None else seek_time
-        clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
-        end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
-        start_offset = int(seek_time * sample_rate // clap_stride)
-        end_offset = int(end_seek_time * sample_rate // clap_stride)
-        wav_embed = full_embed[start_offset:end_offset, ...]
-        wav_embed = wav_embed.mean(dim=0, keepdim=True)
-        return wav_embed.to(self.device)  # [F, D]
-
-    def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
-        """Get CLAP embedding from a batch of text descriptions."""
-        no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from cache when condition dropout
-        if self.text_cache is not None and no_nullified_cond:
-            assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
-            paths = [Path(p) for p in x.path if p is not None]
-            embed = self.text_cache.get_embed_from_cache(paths, x)
-        else:
-            text = [xi if xi is not None else "" for xi in x.text]
-            embed = self._compute_text_embedding(text)
-        if self.normalize:
-            embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
-        return embed
-
-    def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
-        """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
-        no_undefined_paths = all(p is not None for p in x.path)
-        no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from cache when condition dropout
-        if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
-            paths = [Path(p) for p in x.path if p is not None]
-            embed = self.wav_cache.get_embed_from_cache(paths, x)
-        else:
-            embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
-        if self.normalize:
-            embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
-        return embed
-
-    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
-        # Trying to limit as much as possible sync points when the cache is warm.
-        no_undefined_paths = all(p is not None for p in x.path)
-        if self.wav_cache is not None and no_undefined_paths:
-            assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
-            paths = [Path(p) for p in x.path if p is not None]
-            self.wav_cache.populate_embed_cache(paths, x)
-        if self.text_cache is not None and no_undefined_paths:
-            assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
-            paths = [Path(p) for p in x.path if p is not None]
-            self.text_cache.populate_embed_cache(paths, x)
-        return x
-
-    def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        """Extract shared latent representation from either the wav or the text using CLAP."""
-        # decide whether to use text embedding at train time or not
-        use_text_embed = random.random() < self.text_p
-        if self.training and not use_text_embed:
-            embed = self._get_wav_embedding(x)
-            empty_idx = torch.LongTensor([])  # we assume we always have the audio wav
-        else:
-            embed = self._get_text_embedding(x)
-            empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
-        return embed, empty_idx


 def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
@@ -938,25 +556,19 @@ class ConditioningProvider(nn.Module):
         self.device = device
         self.conditioners = nn.ModuleDict(conditioners)

-    @property
-    def joint_embed_conditions(self):
-
+    # @property
+    # def joint_embed_conditions(self):
+    #     return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]

-    @property
-    def has_joint_embed_conditions(self):
-
+    # @property
+    # def has_joint_embed_conditions(self):
+    #     return len(self.joint_embed_conditions) > 0

     @property
     def text_conditions(self):
         return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]

-    @property
-    def wav_conditions(self):
-        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]

-    @property
-    def has_wav_condition(self):
-        return len(self.wav_conditions) > 0

     def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
         """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
@@ -974,15 +586,15 @@ class ConditioningProvider(nn.Module):

         output = {}
         text = self._collate_text(inputs)
-        wavs = self._collate_wavs(inputs)
-        joint_embeds = self._collate_joint_embeds(inputs)
+        # wavs = self._collate_wavs(inputs)
+        # joint_embeds = self._collate_joint_embeds(inputs)

-        assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
-
-
-        )
+        # assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
+        #     f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
+        #     f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
+        # )

-        for attribute, batch in
+        for attribute, batch in text.items():  #, joint_embeds.items()):
             output[attribute] = self.conditioners[attribute].tokenize(batch)
         return output

@@ -1031,102 +643,9 @@ class ConditioningProvider(nn.Module):
                 out[condition].append(text[condition])
         return out

-    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
-        """Generate a dict where the keys are attributes by which we fetch similar wavs,
-        and the values are Tensors of wavs according to said attributes.
-
-        *Note*: by the time the samples reach this function, each sample should have some waveform
-        inside the "wav" attribute. It should be either:
-        1. A real waveform
-        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
-        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
-
-        Args:
-            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
-        Returns:
-            dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
-        """
-        wavs = defaultdict(list)
-        lengths = defaultdict(list)
-        sample_rates = defaultdict(list)
-        paths = defaultdict(list)
-        seek_times = defaultdict(list)
-        out: tp.Dict[str, WavCondition] = {}
-
-        for sample in samples:
-            for attribute in self.wav_conditions:
-                wav, length, sample_rate, path, seek_time = sample.wav[attribute]
-                assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
-                assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
-                # mono-channel conditioning
-                wav = wav.mean(1, keepdim=True)  # [1, 1, T]
-                wavs[attribute].append(wav.flatten())  # [T]
-                lengths[attribute].append(length)
-                sample_rates[attribute].extend(sample_rate)
-                paths[attribute].extend(path)
-                seek_times[attribute].extend(seek_time)
-
-        # stack all wavs to a single tensor
-        for attribute in self.wav_conditions:
-            stacked_wav, _ = collate(wavs[attribute], dim=0)
-            out[attribute] = WavCondition(
-                stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
-                paths[attribute], seek_times[attribute])
-
-        return out
-
-    def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
-        """Generate a dict where the keys are attributes by which we compute joint embeddings,
-        and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
-
-        Args:
-            samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
-        Returns:
-            A dictionary mapping an attribute name to joint embeddings.
-        """
-        texts = defaultdict(list)
-        wavs = defaultdict(list)
-        lengths = defaultdict(list)
-        sample_rates = defaultdict(list)
-        paths = defaultdict(list)
-        seek_times = defaultdict(list)
-        channels: int = 0
-
-        out = {}
-        for sample in samples:
-            for attribute in self.joint_embed_conditions:
-                wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
-                assert wav.dim() == 3
-                if channels == 0:
-                    channels = wav.size(1)
-                else:
-                    assert channels == wav.size(1), "not all audio has same number of channels in batch"
-                assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
-                wav = einops.rearrange(wav, "b c t -> (b c t)")  # [1, C, T] => [C * T]
-                wavs[attribute].append(wav)
-                texts[attribute].extend(text)
-                lengths[attribute].append(length)
-                sample_rates[attribute].extend(sample_rate)
-                paths[attribute].extend(path)
-                seek_times[attribute].extend(seek_time)
-
-        for attribute in self.joint_embed_conditions:
-            stacked_texts = texts[attribute]
-            stacked_paths = paths[attribute]
-            stacked_seek_times = seek_times[attribute]
-            stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
-            stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
-            stacked_sample_rates = sample_rates[attribute]
-            stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
-            assert stacked_lengths.size(0) == stacked_wavs.size(0)
-            assert len(stacked_sample_rates) == stacked_wavs.size(0)
-            assert len(stacked_texts) == stacked_wavs.size(0)
-            out[attribute] = JointEmbedCondition(
-                text=stacked_texts, wav=stacked_wavs,
-                length=stacked_lengths, sample_rate=stacked_sample_rates,
-                path=stacked_paths, seek_time=stacked_seek_times)

-
+


 class ConditionFuser(StreamingModule):
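Note: with the wav and joint-embedding branches removed or commented out, ConditioningProvider now only collates and tokenizes text attributes, and each conditioner still returns an (embedding, mask) pair. A rough usage sketch under that assumption follows; the ConditioningAttributes constructor keyword is an assumption, since the diff only shows the attribute being read back via sample.text.

# Hypothetical sketch (assumed constructor): the text-only path through the trimmed provider.
attributes = [ConditioningAttributes(text={'description': 'upbeat jazz with saxophone'})]
tokenized = provider.tokenize(attributes)   # only _collate_text() feeds this now
condition_tensors = provider(tokenized)     # e.g. {'description': (embeds, mask)} from T5Conditioner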
audiocraft/lm.py
CHANGED
@@ -322,39 +322,7 @@ class TextConditioner(BaseConditioner):
|
|
322 |
...
|
323 |
|
324 |
|
325 |
-
class LUTConditioner(TextConditioner):
|
326 |
-
"""Lookup table TextConditioner.
|
327 |
|
328 |
-
Args:
|
329 |
-
n_bins (int): Number of bins.
|
330 |
-
dim (int): Hidden dim of the model (text-encoder/LUT).
|
331 |
-
output_dim (int): Output dim of the conditioner.
|
332 |
-
tokenizer (str): Name of the tokenizer.
|
333 |
-
pad_idx (int, optional): Index for padding token. Defaults to 0.
|
334 |
-
"""
|
335 |
-
def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
|
336 |
-
super().__init__(dim, output_dim)
|
337 |
-
self.embed = nn.Embedding(n_bins, dim)
|
338 |
-
self.tokenizer: Tokenizer
|
339 |
-
if tokenizer == 'whitespace':
|
340 |
-
self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
|
341 |
-
elif tokenizer == 'noop':
|
342 |
-
self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
|
343 |
-
else:
|
344 |
-
raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
|
345 |
-
|
346 |
-
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
347 |
-
device = self.embed.weight.device
|
348 |
-
tokens, mask = self.tokenizer(x)
|
349 |
-
tokens, mask = tokens.to(device), mask.to(device)
|
350 |
-
return tokens, mask
|
351 |
-
|
352 |
-
def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
|
353 |
-
tokens, mask = inputs
|
354 |
-
embeds = self.embed(tokens)
|
355 |
-
embeds = self.output_proj(embeds)
|
356 |
-
embeds = (embeds * mask.unsqueeze(-1))
|
357 |
-
return embeds, mask
|
358 |
|
359 |
|
360 |
class T5Conditioner(TextConditioner):
|
@@ -453,56 +421,7 @@ class T5Conditioner(TextConditioner):
|
|
453 |
return embeds, mask
|
454 |
|
455 |
|
456 |
-
class WaveformConditioner(BaseConditioner):
|
457 |
-
"""Base class for all conditioners that take a waveform as input.
|
458 |
-
Classes that inherit must implement `_get_wav_embedding` that outputs
|
459 |
-
a continuous tensor, and `_downsampling_factor` that returns the down-sampling
|
460 |
-
factor of the embedding model.
|
461 |
-
|
462 |
-
Args:
|
463 |
-
dim (int): The internal representation dimension.
|
464 |
-
output_dim (int): Output dimension.
|
465 |
-
device (tp.Union[torch.device, str]): Device.
|
466 |
-
"""
|
467 |
-
def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
|
468 |
-
super().__init__(dim, output_dim)
|
469 |
-
self.device = device
|
470 |
-
# if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample.
|
471 |
-
self._use_masking = True
|
472 |
-
|
473 |
-
def tokenize(self, x: WavCondition) -> WavCondition:
|
474 |
-
wav, length, sample_rate, path, seek_time = x
|
475 |
-
assert length is not None
|
476 |
-
return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
|
477 |
-
|
478 |
-
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
479 |
-
"""Gets as input a WavCondition and returns a dense embedding."""
|
480 |
-
raise NotImplementedError()
|
481 |
|
482 |
-
def _downsampling_factor(self):
|
483 |
-
"""Returns the downsampling factor of the embedding model."""
|
484 |
-
raise NotImplementedError()
|
485 |
-
|
486 |
-
def forward(self, x: WavCondition) -> ConditionType:
|
487 |
-
"""Extract condition embedding and mask from a waveform and its metadata.
|
488 |
-
Args:
|
489 |
-
x (WavCondition): Waveform condition containing raw waveform and metadata.
|
490 |
-
Returns:
|
491 |
-
ConditionType: a dense vector representing the conditioning along with its mask
|
492 |
-
"""
|
493 |
-
wav, lengths, *_ = x
|
494 |
-
with torch.no_grad():
|
495 |
-
embeds = self._get_wav_embedding(x)
|
496 |
-
embeds = embeds.to(self.output_proj.weight)
|
497 |
-
embeds = self.output_proj(embeds)
|
498 |
-
|
499 |
-
if lengths is not None and self._use_masking:
|
500 |
-
lengths = lengths / self._downsampling_factor()
|
501 |
-
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
502 |
-
else:
|
503 |
-
mask = torch.ones_like(embeds[..., 0])
|
504 |
-
embeds = (embeds * mask.unsqueeze(-1))
|
505 |
-
return embeds, mask
|
506 |
|
507 |
|
508 |
|
@@ -570,366 +489,13 @@ class JointEmbeddingConditioner(BaseConditioner):
|
|
570 |
return x
|
571 |
|
572 |
|
573 |
-
class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
|
574 |
-
"""Joint Embedding conditioner based on pre-trained CLAP model.
|
575 |
-
|
576 |
-
This CLAP-based conditioner supports a caching mechanism
|
577 |
-
over the computed embeddings for faster training.
|
578 |
|
579 |
-
Args:
|
580 |
-
dim (int): Dimension.
|
581 |
-
output_dim (int): Output dimension.
|
582 |
-
device (str): Device.
|
583 |
-
attribute (str): Attribute used by the conditioner.
|
584 |
-
quantize (bool): Whether to quantize the CLAP embedding.
|
585 |
-
n_q (int): Number of residual quantizers (used if quantize is true).
|
586 |
-
bins (int): Quantizers' codebooks size (used if quantize is true).
|
587 |
-
checkpoint (str): Path to CLAP checkpoint.
|
588 |
-
model_arch (str): CLAP model architecture.
|
589 |
-
enable_fusion (bool): Enable fusion for CLAP model.
|
590 |
-
sample_rate (int): Sample rate used by CLAP model.
|
591 |
-
max_audio_length (float): Maximum audio length for CLAP model.
|
592 |
-
audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
|
593 |
-
normalize (bool): Whether to normalize the CLAP embedding.
|
594 |
-
text_p (float): Probability of using text representation instead of audio at train time.
|
595 |
-
batch_size (Optional[int]): Batch size for CLAP embedding computation.
|
596 |
-
autocast_dtype (str): Autocast for the conditioner.
|
597 |
-
cache_path (Optional[str]): Path for pre-computed embeddings caching.
|
598 |
-
kwargs: Additional parameters for residual vector quantizer.
|
599 |
-
"""
|
600 |
-
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
|
601 |
-
quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
|
602 |
-
enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
|
603 |
-
normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
|
604 |
-
autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
|
605 |
-
try:
|
606 |
-
import laion_clap # type: ignore
|
607 |
-
except ImportError:
|
608 |
-
raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
|
609 |
-
warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
|
610 |
-
"Please retrain all models.")
|
611 |
-
checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
|
612 |
-
clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
|
613 |
-
clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
|
614 |
-
load_clap_state_dict(clap_model, checkpoint)
|
615 |
-
clap_model.eval()
|
616 |
-
clap_model.to(device)
|
617 |
-
super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
|
618 |
-
autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
|
619 |
-
**kwargs)
|
620 |
-
self.checkpoint = checkpoint
|
621 |
-
self.enable_fusion = enable_fusion
|
622 |
-
self.model_arch = model_arch
|
623 |
-
self.clap: laion_clap.CLAP_Module
|
624 |
-
self.clap_tokenize: RobertaTokenizer
|
625 |
-
self.clap_sample_rate = sample_rate
|
626 |
-
self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
|
627 |
-
self.clap_stride = int(self.clap_sample_rate * audio_stride)
|
628 |
-
self.batch_size = batch_size or 1
|
629 |
-
self.normalize = normalize
|
630 |
-
self.text_p = text_p
|
631 |
-
self.__dict__['clap_tokenize'] = clap_tokenize
|
632 |
-
self.__dict__['clap'] = clap_model
|
633 |
-
self.wav_cache, self.text_cache = None, None
|
634 |
-
if cache_path is not None:
|
635 |
-
self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
636 |
-
compute_embed_fn=self._get_wav_embedding_for_cache,
|
637 |
-
extract_embed_fn=self._extract_wav_embedding_chunk)
|
638 |
-
self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
|
639 |
-
compute_embed_fn=self._get_text_embedding_for_cache)
|
640 |
-
|
641 |
-
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
|
642 |
-
# we use the default params from CLAP module here as well
|
643 |
-
return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
|
644 |
-
|
645 |
-
def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
|
646 |
-
"""Compute text embedding from CLAP model on a given a batch of text.
|
647 |
|
648 |
-
Args:
|
649 |
-
text (list[str]): List of text for the batch, with B items.
|
650 |
-
Returns:
|
651 |
-
torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
|
652 |
-
"""
|
653 |
-
with torch.no_grad():
|
654 |
-
embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
|
655 |
-
return embed.view(embed.size(0), 1, embed.size(-1))
|
656 |
-
|
657 |
-
def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
|
658 |
-
x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
659 |
-
"""Get text embedding function for the cache."""
|
660 |
-
text = x.text[idx]
|
661 |
-
text = text if text is not None else ""
|
662 |
-
return self._compute_text_embedding([text])[0]
|
663 |
|
664 |
-
def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
|
665 |
-
"""Preprocess wav to expected format by CLAP model.
|
666 |
-
|
667 |
-
Args:
|
668 |
-
wav (torch.Tensor): Audio wav, of shape [B, C, T].
|
669 |
-
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
|
670 |
-
sample_rates (list[int]): Sample rates for each sample in the batch
|
671 |
-
Returns:
|
672 |
-
torch.Tensor: Audio wav of shape [B, T].
|
673 |
-
"""
|
674 |
-
assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
|
675 |
-
if sample_rates is not None:
|
676 |
-
_wav = []
|
677 |
-
for i, audio in enumerate(wav):
|
678 |
-
sr = sample_rates[i]
|
679 |
-
audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
|
680 |
-
_wav.append(audio)
|
681 |
-
wav = torch.stack(_wav, dim=0)
|
682 |
-
wav = wav.mean(dim=1)
|
683 |
-
return wav
|
684 |
-
|
685 |
-
def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
|
686 |
-
sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
|
687 |
-
"""Compute audio wave embedding from CLAP model.
|
688 |
-
|
689 |
-
Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
|
690 |
-
we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
|
691 |
-
average the resulting embeddings.
|
692 |
-
|
693 |
-
Args:
|
694 |
-
wav (torch.Tensor): Audio wav, of shape [B, C, T].
|
695 |
-
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
|
696 |
-
sample_rates (list[int]): Sample rates for each sample in the batch.
|
697 |
-
reduce_mean (bool): Whether to get the average tensor.
|
698 |
-
Returns:
|
699 |
-
torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
|
700 |
-
"""
|
701 |
-
with torch.no_grad():
|
702 |
-
wav = self._preprocess_wav(wav, length, sample_rates)
|
703 |
-
B, T = wav.shape
|
704 |
-
if T >= self.clap_max_frames:
|
705 |
-
wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
|
706 |
-
else:
|
707 |
-
wav = wav.view(-1, 1, T) # [B, F, T] with F=1
|
708 |
-
wav = einops.rearrange(wav, 'b f t -> (b f) t')
|
709 |
-
embed_list = []
|
710 |
-
for i in range(0, wav.size(0), self.batch_size):
|
711 |
-
_wav = wav[i:i+self.batch_size, ...]
|
712 |
-
_embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
|
713 |
-
embed_list.append(_embed)
|
714 |
-
embed = torch.cat(embed_list, dim=0)
|
715 |
-
embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
|
716 |
-
if reduce_mean:
|
717 |
-
embed = embed.mean(dim=1, keepdim=True)
|
718 |
-
return embed # [B, F, D] with F=1 if reduce_mean is True
|
719 |
-
|
720 |
-
def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
|
721 |
-
x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
722 |
-
"""Compute audio wave embedding for the cache.
|
723 |
-
The embedding is computed on a given audio read from file.
|
724 |
|
725 |
-
Args:
|
726 |
-
path (str or Path): Path to the full audio file.
|
727 |
-
Returns:
|
728 |
-
torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
|
729 |
-
"""
|
730 |
-
wav, sr = soundfile.read(path) # [C, T]
|
731 |
-
wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
|
732 |
-
wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
|
733 |
-
embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
|
734 |
-
return embed.squeeze(0) # [F, D]
|
735 |
-
|
736 |
-
def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
737 |
-
"""Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
|
738 |
-
|
739 |
-
Args:
|
740 |
-
full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
|
741 |
-
x (JointEmbedCondition): Joint embedding condition for the full batch.
|
742 |
-
idx (int): Index considered for the given embedding to extract.
|
743 |
-
Returns:
|
744 |
-
torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
|
745 |
-
"""
|
746 |
-
sample_rate = x.sample_rate[idx]
|
747 |
-
seek_time = x.seek_time[idx]
|
748 |
-
seek_time = 0. if seek_time is None else seek_time
|
749 |
-
clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
|
750 |
-
end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
|
751 |
-
start_offset = int(seek_time * sample_rate // clap_stride)
|
752 |
-
end_offset = int(end_seek_time * sample_rate // clap_stride)
|
753 |
-
wav_embed = full_embed[start_offset:end_offset, ...]
|
754 |
-
wav_embed = wav_embed.mean(dim=0, keepdim=True)
|
755 |
-
return wav_embed.to(self.device) # [F, D]
|
756 |
-
|
757 |
-
def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
|
758 |
-
"""Get CLAP embedding from a batch of text descriptions."""
|
759 |
-
no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
|
760 |
-
if self.text_cache is not None and no_nullified_cond:
|
761 |
-
assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
|
762 |
-
paths = [Path(p) for p in x.path if p is not None]
|
763 |
-
embed = self.text_cache.get_embed_from_cache(paths, x)
|
764 |
-
else:
|
765 |
-
text = [xi if xi is not None else "" for xi in x.text]
|
766 |
-
embed = self._compute_text_embedding(text)
|
767 |
-
if self.normalize:
|
768 |
-
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
|
769 |
-
return embed
|
770 |
-
|
771 |
-
def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
|
772 |
-
"""Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
|
773 |
-
no_undefined_paths = all(p is not None for p in x.path)
|
774 |
-
no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
|
775 |
-
if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
|
776 |
-
paths = [Path(p) for p in x.path if p is not None]
|
777 |
-
embed = self.wav_cache.get_embed_from_cache(paths, x)
|
778 |
-
else:
|
779 |
-
embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
|
780 |
-
if self.normalize:
|
781 |
-
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
|
782 |
-
return embed
|
783 |
-
|
784 |
-
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
785 |
-
# Trying to limit as much as possible sync points when the cache is warm.
|
786 |
-
no_undefined_paths = all(p is not None for p in x.path)
|
787 |
-
if self.wav_cache is not None and no_undefined_paths:
|
788 |
-
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
|
789 |
-
paths = [Path(p) for p in x.path if p is not None]
|
790 |
-
self.wav_cache.populate_embed_cache(paths, x)
|
791 |
-
if self.text_cache is not None and no_undefined_paths:
|
792 |
-
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
|
793 |
-
paths = [Path(p) for p in x.path if p is not None]
|
794 |
-
self.text_cache.populate_embed_cache(paths, x)
|
795 |
-
return x
|
796 |
-
|
797 |
-
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
798 |
-
"""Extract shared latent representation from either the wav or the text using CLAP."""
|
799 |
-
# decide whether to use text embedding at train time or not
|
800 |
-
use_text_embed = random.random() < self.text_p
|
801 |
-
if self.training and not use_text_embed:
|
802 |
-
embed = self._get_wav_embedding(x)
|
803 |
-
empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
|
804 |
-
else:
|
805 |
-
embed = self._get_text_embedding(x)
|
806 |
-
empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
|
807 |
-
return embed, empty_idx
|
808 |
-
|
809 |
-
|
810 |
-
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
|
811 |
-
"""Utility function for nullifying an attribute inside an ConditioningAttributes object.
|
812 |
-
If the condition is of type "wav", then nullify it using `nullify_condition` function.
|
813 |
-
If the condition is of any other type, set its value to None.
|
814 |
-
Works in-place.
|
815 |
-
"""
|
816 |
-
if condition_type not in ['text', 'wav', 'joint_embed']:
|
817 |
-
raise ValueError(
|
818 |
-
"dropout_condition got an unexpected condition type!"
|
819 |
-
f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
|
820 |
-
)
|
821 |
-
|
822 |
-
if condition not in getattr(sample, condition_type):
|
823 |
-
raise ValueError(
|
824 |
-
"dropout_condition received an unexpected condition!"
|
825 |
-
f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
|
826 |
-
f" but got '{condition}' of type '{condition_type}'!"
|
827 |
-
)
|
828 |
-
|
829 |
-
if condition_type == 'wav':
|
830 |
-
wav_cond = sample.wav[condition]
|
831 |
-
sample.wav[condition] = nullify_wav(wav_cond)
|
832 |
-
elif condition_type == 'joint_embed':
|
833 |
-
embed = sample.joint_embed[condition]
|
834 |
-
sample.joint_embed[condition] = nullify_joint_embed(embed)
|
835 |
-
else:
|
836 |
-
sample.text[condition] = None
|
837 |
-
|
838 |
-
return sample
|
839 |
-
|
840 |
-
|
841 |
-
class DropoutModule(nn.Module):
|
842 |
-
"""Base module for all dropout modules."""
|
843 |
-
def __init__(self, seed: int = 1234):
|
844 |
-
super().__init__()
|
845 |
-
self.rng = torch.Generator()
|
846 |
-
self.rng.manual_seed(seed)
|
847 |
-
|
848 |
-
|
849 |
-
class AttributeDropout(DropoutModule):
|
850 |
-
"""Dropout with a given probability per attribute.
|
851 |
-
This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
|
852 |
-
to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
|
853 |
-
This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
|
854 |
-
must also be dropped.
|
855 |
-
|
856 |
-
Args:
|
857 |
-
p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
|
858 |
-
...
|
859 |
-
"genre": 0.1,
|
860 |
-
"artist": 0.5,
|
861 |
-
"wav": 0.25,
|
862 |
-
...
|
863 |
-
active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
|
864 |
-
seed (int, optional): Random seed.
|
865 |
-
"""
|
866 |
-
def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
|
867 |
-
super().__init__(seed=seed)
|
868 |
-
self.active_on_eval = active_on_eval
|
869 |
-
# construct dict that return the values from p otherwise 0
|
870 |
-
self.p = {}
|
871 |
-
for condition_type, probs in p.items():
|
872 |
-
self.p[condition_type] = defaultdict(lambda: 0, probs)
|
873 |
-
|
874 |
-
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
875 |
-
"""
|
876 |
-
Args:
|
877 |
-
samples (list[ConditioningAttributes]): List of conditions.
|
878 |
-
Returns:
|
879 |
-
list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
|
880 |
-
"""
|
881 |
-
if not self.training and not self.active_on_eval:
|
882 |
-
return samples
|
883 |
-
|
884 |
-
samples = deepcopy(samples)
|
885 |
-
for condition_type, ps in self.p.items(): # for condition types [text, wav]
|
886 |
-
for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
|
887 |
-
if torch.rand(1, generator=self.rng).item() < p:
|
888 |
-
for sample in samples:
|
889 |
-
dropout_condition(sample, condition_type, condition)
|
890 |
-
return samples
|
891 |
-
|
892 |
-
def __repr__(self):
|
893 |
-
return f"AttributeDropout({dict(self.p)})"
|
894 |
-
|
895 |
-
|
896 |
-
class ClassifierFreeGuidanceDropout(DropoutModule):
|
897 |
-
"""Classifier Free Guidance dropout.
|
898 |
-
All attributes are dropped with the same probability.
|
899 |
-
|
900 |
-
Args:
|
901 |
-
p (float): Probability to apply condition dropout during training.
|
902 |
-
seed (int): Random seed.
|
903 |
-
"""
|
904 |
-
def __init__(self, p: float, seed: int = 1234):
|
905 |
-
super().__init__(seed=seed)
|
906 |
-
self.p = p
|
907 |
-
|
908 |
-
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
909 |
-
"""
|
910 |
-
Args:
|
911 |
-
samples (list[ConditioningAttributes]): List of conditions.
|
912 |
-
Returns:
|
913 |
-
list[ConditioningAttributes]: List of conditions after all attributes were set to None.
|
914 |
-
"""
|
915 |
-
if not self.training:
|
916 |
-
return samples
|
917 |
|
918 |
-
# decide on which attributes to drop in a batched fashion
|
919 |
-
drop = torch.rand(1, generator=self.rng).item() < self.p
|
920 |
-
if not drop:
|
921 |
-
return samples
|
922 |
|
923 |
-
# nullify conditions of all attributes
|
924 |
-
samples = deepcopy(samples)
|
925 |
-
for condition_type in ["wav", "text"]:
|
926 |
-
for sample in samples:
|
927 |
-
for condition in sample.attributes[condition_type]:
|
928 |
-
dropout_condition(sample, condition_type, condition)
|
929 |
-
return samples
|
930 |
|
931 |
-
def __repr__(self):
|
932 |
-
return f"ClassifierFreeGuidanceDropout(p={self.p})"
|
933 |
|
934 |
|
935 |
class ConditioningProvider(nn.Module):
|
@@ -1355,8 +921,8 @@ class LMModel(StreamingModule):
                  **kwargs):
         super().__init__()
         self.cfg_coef = cfg_coef
-
-
+
+
         self.condition_provider = condition_provider
         self.fuser = fuser
         self.card = card
@@ -1447,10 +1013,7 @@ class LMModel(StreamingModule):
         input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
         if condition_tensors is None:
             assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
-
-            conditions = self.cfg_dropout(conditions)
-            conditions = self.att_dropout(conditions)
-            tokenized = self.condition_provider.tokenize(conditions)
+
             # encode conditions and fuse, both have a streaming cache to not recompute when generating.
             condition_tensors = self.condition_provider(tokenized)
         else:
@@ -1661,7 +1224,7 @@ class LMModel(StreamingModule):
         cfg_conditions: CFGConditions
         two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
         if conditions:
-            null_conditions =
+            null_conditions = conditions
             if two_step_cfg:
                 cfg_conditions = (
                     self.condition_provider(self.condition_provider.tokenize(conditions)),
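Note: the last hunk replaces the truncated null_conditions assignment with null_conditions = conditions. Upstream AudioCraft normally builds the null conditions by dropping every attribute (ClassifierFreeGuidanceDropout with p=1.0), so reusing the same conditions makes the unconditional branch of classifier-free guidance identical to the conditional one. A small sketch of the consequence, assuming the usual combination with cfg_coef (the combination formula itself is not part of this diff):

# Hypothetical sketch: classifier-free guidance with identical branches.
# If the "null" logits equal the conditional logits, cfg_coef cancels out:
#   uncond + cfg_coef * (cond - uncond) == cond   whenever uncond == cond
logits = uncond_logits + cfg_coef * (cond_logits - uncond_logits)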