clean unused funs
- README.md +0 -4
- audiocraft/builders.py +8 -8
- audiocraft/chroma.py +0 -66
- audiocraft/conditioners.py +1 -189
README.md
CHANGED
@@ -59,10 +59,6 @@ python landscape2soundscape.py
 
 # Videos / Examples
 
-<iframe width="420" height="315"
-src="https://youtu.be/wWC8DpOKVvQ">
-</iframe>
-
 Video where Native voice is replaced with English TTS voice
 
 
audiocraft/builders.py
CHANGED
@@ -28,7 +28,6 @@ from .codebooks_patterns import (
 )
 from .conditioners import (
     BaseConditioner,
-    ChromaStemConditioner,
     CLAPEmbeddingConditioner,
     ConditionFuser,
     ConditioningProvider,
@@ -142,13 +141,13 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
             conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
         elif model_type == 'lut':
             conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
-        elif model_type == 'chroma_stem':
-            conditioners[str(cond)] = ChromaStemConditioner(
-                output_dim=output_dim,
-                duration=duration,
-                device=device,
-                **model_args
-            )
+        # elif model_type == 'chroma_stem':
+        #     conditioners[str(cond)] = ChromaStemConditioner(
+        #         output_dim=output_dim,
+        #         duration=duration,
+        #         device=device,
+        #         **model_args
+        #     )
         elif model_type == 'clap':
             conditioners[str(cond)] = CLAPEmbeddingConditioner(
                 output_dim=output_dim,
@@ -158,6 +157,7 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
         else:
             raise ValueError(f"Unrecognized conditioning model: {model_type}")
     conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
+    print(' COND\n',conditioner)
     return conditioner
 
 
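Note: with the chroma_stem branch commented out, a conditioner config that still sets its model to 'chroma_stem' is no longer handled and falls through to the final else. A minimal standalone sketch of the resulting dispatch behavior (the model_type value is hypothetical):

    # After this change, 'chroma_stem' reaches the ValueError branch
    # of get_conditioner_provider's dispatch.
    model_type = 'chroma_stem'  # hypothetical config value
    if model_type == 't5':
        pass  # T5Conditioner(...)
    elif model_type == 'lut':
        pass  # LUTConditioner(...)
    elif model_type == 'clap':
        pass  # CLAPEmbeddingConditioner(...)
    else:
        raise ValueError(f"Unrecognized conditioning model: {model_type}")
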
audiocraft/chroma.py
DELETED
@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-import typing as tp
-
-from einops import rearrange
-from librosa import filters
-import torch
-from torch import nn
-import torch.nn.functional as F
-import torchaudio
-
-
-class ChromaExtractor(nn.Module):
-    """Chroma extraction and quantization.
-
-    Args:
-        sample_rate (int): Sample rate for the chroma extraction.
-        n_chroma (int): Number of chroma bins for the chroma extraction.
-        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
-        nfft (int, optional): Number of FFT.
-        winlen (int, optional): Window length.
-        winhop (int, optional): Window hop size.
-        argmax (bool, optional): Whether to use argmax. Defaults to False.
-        norm (float, optional): Norm for chroma normalization. Defaults to inf.
-    """
-    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
-                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
-                 norm: float = torch.inf):
-        super().__init__()
-        self.winlen = winlen or 2 ** radix2_exp
-        self.nfft = nfft or self.winlen
-        self.winhop = winhop or (self.winlen // 4)
-        self.sample_rate = sample_rate
-        self.n_chroma = n_chroma
-        self.norm = norm
-        self.argmax = argmax
-        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
-                                                                       n_chroma=self.n_chroma)), persistent=False)
-        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
-                                                      hop_length=self.winhop, power=2, center=True,
-                                                      pad=0, normalized=True)
-
-    def forward(self, wav: torch.Tensor) -> torch.Tensor:
-        T = wav.shape[-1]
-        # in case we are getting a wav that was dropped out (nullified)
-        # from the conditioner, make sure wav length is no less that nfft
-        if T < self.nfft:
-            pad = self.nfft - T
-            r = 0 if pad % 2 == 0 else 1
-            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
-            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
-
-        spec = self.spec(wav).squeeze(1)
-        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
-        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
-        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
-
-        if self.argmax:
-            idx = norm_chroma.argmax(-1, keepdim=True)
-            norm_chroma[:] = 0
-            norm_chroma.scatter_(dim=-1, index=idx, value=1)
-
-        return norm_chroma
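For reference, the deleted ChromaExtractor projected a power spectrogram onto 12 pitch classes via a librosa chroma filterbank and normalized each frame. A self-contained sketch of the same computation using the class defaults (radix2_exp=12, so a 4096-sample window with hop 1024; the input here is random noise, for illustration only):

    import torch
    import torch.nn.functional as F
    import torchaudio
    from librosa import filters

    sample_rate, n_chroma, nfft = 44100, 12, 4096   # class defaults: 2 ** radix2_exp = 4096
    winlen, winhop = nfft, nfft // 4                # window 4096, hop 1024

    # Filterbank mapping the nfft // 2 + 1 STFT bins onto 12 pitch classes.
    fbanks = torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=nfft, tuning=0,
                                             n_chroma=n_chroma)).float()
    spec = torchaudio.transforms.Spectrogram(n_fft=nfft, win_length=winlen, hop_length=winhop,
                                             power=2, center=True, pad=0, normalized=True)

    wav = torch.randn(1, sample_rate)               # one second of noise as a stand-in signal
    raw_chroma = torch.einsum('cf,...ft->...ct', fbanks, spec(wav))
    norm_chroma = F.normalize(raw_chroma, p=torch.inf, dim=-2, eps=1e-6)
    print(norm_chroma.shape)                        # torch.Size([1, 12, 44]): ~43 frames per second
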
audiocraft/conditioners.py
CHANGED
@@ -26,7 +26,7 @@ import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 from .streaming import StreamingModule
 
-from .chroma import ChromaExtractor
+
 from .streaming import StreamingModule
 from .transformer import create_sin_embedding
 
@@ -500,195 +500,7 @@ class WaveformConditioner(BaseConditioner):
         return embeds, mask
 
 
-class ChromaStemConditioner(WaveformConditioner):
-    """Chroma conditioner based on stems.
-    The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
-    the drums and bass often dominate the chroma leading to the chroma features
-    not containing information about the melody.
-
-    Args:
-        output_dim (int): Output dimension for the conditioner.
-        sample_rate (int): Sample rate for the chroma extractor.
-        n_chroma (int): Number of chroma bins for the chroma extractor.
-        radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
-        duration (int): duration used during training. This is later used for correct padding
-            in case we are using chroma as prefix.
-        match_len_on_eval (bool, optional): if True then all chromas are padded to the training
-            duration. Defaults to False.
-        eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as
-            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
-            Defaults to None.
-        n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
-        device (tp.Union[torch.device, str], optional): Device for the conditioner.
-        **kwargs: Additional parameters for the chroma extractor.
-    """
-    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
-                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
-                 n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
-                 device: tp.Union[torch.device, str] = 'cpu', **kwargs):
-        from demucs import pretrained
-        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
-        self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
-        self.sample_rate = sample_rate
-        self.match_len_on_eval = match_len_on_eval
-        if match_len_on_eval:
-            self._use_masking = False
-        self.duration = duration
-        self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
-        stem_sources: list = self.demucs.sources  # type: ignore
-        self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
-        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
-                                      radix2_exp=radix2_exp, **kwargs).to(device)
-        self.chroma_len = self._get_chroma_len()
-        self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
-        self.cache = None
-        if cache_path is not None:
-            self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
-                                        compute_embed_fn=self._get_full_chroma_for_cache,
-                                        extract_embed_fn=self._extract_chroma_chunk)
-
-    def _downsampling_factor(self) -> int:
-        return self.chroma.winhop
-
-    def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
-        """Load pre-defined waveforms from a json.
-        These waveforms will be used for chroma extraction during evaluation.
-        This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
-        """
-        if path is None:
-            return None
-
-        logger.info(f"Loading evaluation wavs from {path}")
-        from audiocraft.data.audio_dataset import AudioDataset
-        dataset: AudioDataset = AudioDataset.from_meta(
-            path, segment_duration=self.duration, min_audio_duration=self.duration,
-            sample_rate=self.sample_rate, channels=1)
-
-        if len(dataset) > 0:
-            eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
-            logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
-            return eval_wavs
-        else:
-            raise ValueError("Could not find evaluation wavs, check lengths of wavs")
-
-    def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
-        self.eval_wavs = eval_wavs
-
-    def has_eval_wavs(self) -> bool:
-        return self.eval_wavs is not None
-
-    def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
-        """Sample wavs from a predefined list."""
-        assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
-        total_eval_wavs = len(self.eval_wavs)
-        out = self.eval_wavs
-        if num_samples > total_eval_wavs:
-            out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
-        return out[torch.randperm(len(out))][:num_samples]
-
-    def _get_chroma_len(self) -> int:
-        """Get length of chroma during training."""
-        dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
-        dummy_chr = self.chroma(dummy_wav)
-        return dummy_chr.shape[1]
-
-    @torch.no_grad()
-    def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
-        """Get parts of the wav that holds the melody, extracting the main stems from the wav."""
-        from demucs.apply import apply_model
-        from demucs.audio import convert_audio
-        with self.autocast:
-            wav = convert_audio(
-                wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels)  # type: ignore
-            stems = apply_model(self.demucs, wav, device=self.device)
-            stems = stems[:, self.stem_indices]  # extract relevant stems for melody conditioning
-            mix_wav = stems.sum(1)  # merge extracted stems to single waveform
-            mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1)  # type: ignore
-            return mix_wav
-
-    @torch.no_grad()
-    def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
-        """Extract chroma features from the waveform."""
-        with self.autocast:
-            return self.chroma(wav)
-
-    @torch.no_grad()
-    def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
-        """Compute wav embedding, applying stem and chroma extraction."""
-        # avoid 0-size tensors when we are working with null conds
-        if wav.shape[-1] == 1:
-            return self._extract_chroma(wav)
-        stems = self._get_stemmed_wav(wav, sample_rate)
-        chroma = self._extract_chroma(stems)
-        return chroma
-
-    @torch.no_grad()
-    def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
-        """Extract chroma from the whole audio waveform at the given path."""
-        wav, sr = soundfile.read(path)
-        wav = wav[None].to(self.device)
-        wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
-        chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
-        return chroma
-
-    def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
-        """Extract a chunk of chroma from the full chroma derived from the full waveform."""
-        wav_length = x.wav.shape[-1]
-        seek_time = x.seek_time[idx]
-        assert seek_time is not None, (
-            "WavCondition seek_time is required "
-            "when extracting chroma chunks from pre-computed chroma.")
-        full_chroma = full_chroma.float()
-        frame_rate = self.sample_rate / self._downsampling_factor()
-        target_length = int(frame_rate * wav_length / self.sample_rate)
-        index = int(frame_rate * seek_time)
-        out = full_chroma[index: index + target_length]
-        out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
-        return out.to(self.device)
-
-    @torch.no_grad()
-    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
-        """Get the wav embedding from the WavCondition.
-        The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
-        or will rely on the embedding cache to load the pre-computed embedding if relevant.
-        """
-        sampled_wav: tp.Optional[torch.Tensor] = None
-        if not self.training and self.eval_wavs is not None:
-            warn_once(logger, "Using precomputed evaluation wavs!")
-            sampled_wav = self._sample_eval_wavs(len(x.wav))
-
-        no_undefined_paths = all(p is not None for p in x.path)
-        no_nullified_cond = x.wav.shape[-1] > 1
-        if sampled_wav is not None:
-            chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
-        elif self.cache is not None and no_undefined_paths and no_nullified_cond:
-            paths = [Path(p) for p in x.path if p is not None]
-            chroma = self.cache.get_embed_from_cache(paths, x)
-        else:
-            assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
-            chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
-
-        if self.match_len_on_eval:
-            B, T, C = chroma.shape
-            if T > self.chroma_len:
-                chroma = chroma[:, :self.chroma_len]
-                logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
-            elif T < self.chroma_len:
-                n_repeat = int(math.ceil(self.chroma_len / T))
-                chroma = chroma.repeat(1, n_repeat, 1)
-                chroma = chroma[:, :self.chroma_len]
-                logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
-
-        return chroma
 
-    def tokenize(self, x: WavCondition) -> WavCondition:
-        """Apply WavConditioner tokenization and populate cache if needed."""
-        x = super().tokenize(x)
-        no_undefined_paths = all(p is not None for p in x.path)
-        if self.cache is not None and no_undefined_paths:
-            paths = [Path(p) for p in x.path if p is not None]
-            self.cache.populate_embed_cache(paths, x)
-        return x
 
 
 class JointEmbeddingConditioner(BaseConditioner):
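
For reference, the removed ChromaStemConditioner cached chroma for whole tracks and then sliced out chunks by converting seconds into chroma frames through the extractor's hop size. A worked sketch of the indexing used by _extract_chroma_chunk, with illustrative numbers (the extractor defaults, a hypothetical 10 s conditioning segment, and a 30 s seek):

    # Chunk indexing math from the removed _extract_chroma_chunk.
    sample_rate, winhop = 44100, 1024                 # extractor defaults (hop = 4096 // 4)
    frame_rate = sample_rate / winhop                 # ~43.07 chroma frames per second

    wav_length = 10 * sample_rate                     # hypothetical 10 s conditioning wav
    seek_time = 30.0                                  # chunk starts 30 s into the cached track

    target_length = int(frame_rate * wav_length / sample_rate)  # 430 frames cover 10 s
    index = int(frame_rate * seek_time)                          # 1291: first frame of the chunk
    # The cached tensor is sliced as full_chroma[index:index + target_length],
    # then zero-padded at the end if the track ends early.
    print(index, target_length)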