Dionyssos committed
Commit 06aa0fc · 1 Parent(s): 7abc8f8

clean unused funs

README.md CHANGED
```diff
@@ -59,10 +59,6 @@ python landscape2soundscape.py
 
 # Videos / Examples
 
-<iframe width="420" height="315"
-src="https://youtu.be/wWC8DpOKVvQ">
-</iframe>
-
 Video where Native voice is replaced with English TTS voice
 
 
```
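Worth noting: the removed embed was dead markup anyway. YouTube refuses to load `youtu.be` share URLs inside an `<iframe>`; an embeddable player needs the `https://www.youtube.com/embed/wWC8DpOKVvQ` form of the URL. Dropping the broken `<iframe>` and keeping the plain description line is therefore a fix as well as a cleanup.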
 
audiocraft/builders.py CHANGED
```diff
@@ -28,7 +28,6 @@ from .codebooks_patterns import (
 )
 from .conditioners import (
     BaseConditioner,
-    ChromaStemConditioner,
     CLAPEmbeddingConditioner,
     ConditionFuser,
     ConditioningProvider,
@@ -142,13 +141,13 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
             conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
         elif model_type == 'lut':
             conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
-        elif model_type == 'chroma_stem':
-            conditioners[str(cond)] = ChromaStemConditioner(
-                output_dim=output_dim,
-                duration=duration,
-                device=device,
-                **model_args
-            )
+        # elif model_type == 'chroma_stem':
+        #     conditioners[str(cond)] = ChromaStemConditioner(
+        #         output_dim=output_dim,
+        #         duration=duration,
+        #         device=device,
+        #         **model_args
+        #     )
         elif model_type == 'clap':
             conditioners[str(cond)] = CLAPEmbeddingConditioner(
                 output_dim=output_dim,
@@ -158,6 +157,7 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
         else:
             raise ValueError(f"Unrecognized conditioning model: {model_type}")
     conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
+    print(' COND\n',conditioner)
     return conditioner
 
 
```
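With the import gone and the `chroma_stem` branch commented out, a config that still requests melody conditioning now falls through to the `ValueError` branch of `get_conditioner_provider`. A minimal sketch of the post-commit dispatch, with a hypothetical dict standing in for the real omegaconf values:

```python
# Hypothetical stand-in for a cfg.conditioners entry after this commit.
cond_cfg = {'model': 'chroma_stem'}  # leftover melody-conditioning config

model_type = cond_cfg['model']
if model_type in ('t5', 'lut', 'clap'):
    pass  # conditioner types the dispatch still handles
else:
    # a 'chroma_stem' config now ends up here, like any unknown type
    raise ValueError(f"Unrecognized conditioning model: {model_type}")
```

The added `print(' COND\n',conditioner)` is a debug trace: it dumps the assembled `ConditioningProvider` every time a model is built.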
 
audiocraft/chroma.py DELETED
```diff
@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-import typing as tp
-
-from einops import rearrange
-from librosa import filters
-import torch
-from torch import nn
-import torch.nn.functional as F
-import torchaudio
-
-
-class ChromaExtractor(nn.Module):
-    """Chroma extraction and quantization.
-
-    Args:
-        sample_rate (int): Sample rate for the chroma extraction.
-        n_chroma (int): Number of chroma bins for the chroma extraction.
-        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
-        nfft (int, optional): Number of FFT.
-        winlen (int, optional): Window length.
-        winhop (int, optional): Window hop size.
-        argmax (bool, optional): Whether to use argmax. Defaults to False.
-        norm (float, optional): Norm for chroma normalization. Defaults to inf.
-    """
-    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
-                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
-                 norm: float = torch.inf):
-        super().__init__()
-        self.winlen = winlen or 2 ** radix2_exp
-        self.nfft = nfft or self.winlen
-        self.winhop = winhop or (self.winlen // 4)
-        self.sample_rate = sample_rate
-        self.n_chroma = n_chroma
-        self.norm = norm
-        self.argmax = argmax
-        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
-                                                                       n_chroma=self.n_chroma)), persistent=False)
-        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
-                                                      hop_length=self.winhop, power=2, center=True,
-                                                      pad=0, normalized=True)
-
-    def forward(self, wav: torch.Tensor) -> torch.Tensor:
-        T = wav.shape[-1]
-        # in case we are getting a wav that was dropped out (nullified)
-        # from the conditioner, make sure wav length is no less that nfft
-        if T < self.nfft:
-            pad = self.nfft - T
-            r = 0 if pad % 2 == 0 else 1
-            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
-            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
-
-        spec = self.spec(wav).squeeze(1)
-        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
-        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
-        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
-
-        if self.argmax:
-            idx = norm_chroma.argmax(-1, keepdim=True)
-            norm_chroma[:] = 0
-            norm_chroma.scatter_(dim=-1, index=idx, value=1)
-
-        return norm_chroma
```
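For reference, a minimal sketch of how the deleted extractor was driven, inferred from the signature and `forward()` above (the import path assumes the pre-commit tree; the class no longer exists after this commit):

```python
import torch
from audiocraft.chroma import ChromaExtractor  # import path as it was before this commit

# Defaults from the deleted __init__: 12 chroma bins, a 2**12-sample STFT window,
# a hop of a quarter window, and inf-norm normalization per frame.
extractor = ChromaExtractor(sample_rate=44100)
wav = torch.randn(1, 1, 44100)   # (batch, channels, samples): 1 s of noise
chroma = extractor(wav)          # -> (batch, frames, n_chroma), here (1, 44, 12)
# Each frame's chroma vector is scaled so its largest entry is 1 (p=inf norm);
# with argmax=True the output would instead be one-hot per frame.
```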
audiocraft/conditioners.py CHANGED
```diff
@@ -26,7 +26,7 @@ import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 from .streaming import StreamingModule
 
-from .chroma import ChromaExtractor
+
 from .streaming import StreamingModule
 from .transformer import create_sin_embedding
 
@@ -500,195 +500,7 @@ class WaveformConditioner(BaseConditioner):
         return embeds, mask
 
 
-class ChromaStemConditioner(WaveformConditioner):
-    """Chroma conditioner based on stems.
-    The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
-    the drums and bass often dominate the chroma leading to the chroma features
-    not containing information about the melody.
-
-    Args:
-        output_dim (int): Output dimension for the conditioner.
-        sample_rate (int): Sample rate for the chroma extractor.
-        n_chroma (int): Number of chroma bins for the chroma extractor.
-        radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
-        duration (int): duration used during training. This is later used for correct padding
-            in case we are using chroma as prefix.
-        match_len_on_eval (bool, optional): if True then all chromas are padded to the training
-            duration. Defaults to False.
-        eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as
-            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
-            Defaults to None.
-        n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
-        device (tp.Union[torch.device, str], optional): Device for the conditioner.
-        **kwargs: Additional parameters for the chroma extractor.
-    """
-    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
-                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
-                 n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
-                 device: tp.Union[torch.device, str] = 'cpu', **kwargs):
-        from demucs import pretrained
-        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
-        self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
-        self.sample_rate = sample_rate
-        self.match_len_on_eval = match_len_on_eval
-        if match_len_on_eval:
-            self._use_masking = False
-        self.duration = duration
-        self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
-        stem_sources: list = self.demucs.sources  # type: ignore
-        self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
-        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
-                                      radix2_exp=radix2_exp, **kwargs).to(device)
-        self.chroma_len = self._get_chroma_len()
-        self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
-        self.cache = None
-        if cache_path is not None:
-            self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
-                                        compute_embed_fn=self._get_full_chroma_for_cache,
-                                        extract_embed_fn=self._extract_chroma_chunk)
-
-    def _downsampling_factor(self) -> int:
-        return self.chroma.winhop
-
-    def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
-        """Load pre-defined waveforms from a json.
-        These waveforms will be used for chroma extraction during evaluation.
-        This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
-        """
-        if path is None:
-            return None
-
-        logger.info(f"Loading evaluation wavs from {path}")
-        from audiocraft.data.audio_dataset import AudioDataset
-        dataset: AudioDataset = AudioDataset.from_meta(
-            path, segment_duration=self.duration, min_audio_duration=self.duration,
-            sample_rate=self.sample_rate, channels=1)
-
-        if len(dataset) > 0:
-            eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
-            logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
-            return eval_wavs
-        else:
-            raise ValueError("Could not find evaluation wavs, check lengths of wavs")
-
-    def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
-        self.eval_wavs = eval_wavs
-
-    def has_eval_wavs(self) -> bool:
-        return self.eval_wavs is not None
-
-    def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
-        """Sample wavs from a predefined list."""
-        assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
-        total_eval_wavs = len(self.eval_wavs)
-        out = self.eval_wavs
-        if num_samples > total_eval_wavs:
-            out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
-        return out[torch.randperm(len(out))][:num_samples]
-
-    def _get_chroma_len(self) -> int:
-        """Get length of chroma during training."""
-        dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
-        dummy_chr = self.chroma(dummy_wav)
-        return dummy_chr.shape[1]
-
-    @torch.no_grad()
-    def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
-        """Get parts of the wav that holds the melody, extracting the main stems from the wav."""
-        from demucs.apply import apply_model
-        from demucs.audio import convert_audio
-        with self.autocast:
-            wav = convert_audio(
-                wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels)  # type: ignore
-            stems = apply_model(self.demucs, wav, device=self.device)
-            stems = stems[:, self.stem_indices]  # extract relevant stems for melody conditioning
-            mix_wav = stems.sum(1)  # merge extracted stems to single waveform
-            mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1)  # type: ignore
-            return mix_wav
-
-    @torch.no_grad()
-    def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
-        """Extract chroma features from the waveform."""
-        with self.autocast:
-            return self.chroma(wav)
-
-    @torch.no_grad()
-    def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
-        """Compute wav embedding, applying stem and chroma extraction."""
-        # avoid 0-size tensors when we are working with null conds
-        if wav.shape[-1] == 1:
-            return self._extract_chroma(wav)
-        stems = self._get_stemmed_wav(wav, sample_rate)
-        chroma = self._extract_chroma(stems)
-        return chroma
-
-    @torch.no_grad()
-    def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
-        """Extract chroma from the whole audio waveform at the given path."""
-        wav, sr = soundfile.read(path)
-        wav = wav[None].to(self.device)
-        wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
-        chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
-        return chroma
-
-    def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
-        """Extract a chunk of chroma from the full chroma derived from the full waveform."""
-        wav_length = x.wav.shape[-1]
-        seek_time = x.seek_time[idx]
-        assert seek_time is not None, (
-            "WavCondition seek_time is required "
-            "when extracting chroma chunks from pre-computed chroma.")
-        full_chroma = full_chroma.float()
-        frame_rate = self.sample_rate / self._downsampling_factor()
-        target_length = int(frame_rate * wav_length / self.sample_rate)
-        index = int(frame_rate * seek_time)
-        out = full_chroma[index: index + target_length]
-        out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
-        return out.to(self.device)
-
-    @torch.no_grad()
-    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
-        """Get the wav embedding from the WavCondition.
-        The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
-        or will rely on the embedding cache to load the pre-computed embedding if relevant.
-        """
-        sampled_wav: tp.Optional[torch.Tensor] = None
-        if not self.training and self.eval_wavs is not None:
-            warn_once(logger, "Using precomputed evaluation wavs!")
-            sampled_wav = self._sample_eval_wavs(len(x.wav))
-
-        no_undefined_paths = all(p is not None for p in x.path)
-        no_nullified_cond = x.wav.shape[-1] > 1
-        if sampled_wav is not None:
-            chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
-        elif self.cache is not None and no_undefined_paths and no_nullified_cond:
-            paths = [Path(p) for p in x.path if p is not None]
-            chroma = self.cache.get_embed_from_cache(paths, x)
-        else:
-            assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
-            chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
-
-        if self.match_len_on_eval:
-            B, T, C = chroma.shape
-            if T > self.chroma_len:
-                chroma = chroma[:, :self.chroma_len]
-                logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
-            elif T < self.chroma_len:
-                n_repeat = int(math.ceil(self.chroma_len / T))
-                chroma = chroma.repeat(1, n_repeat, 1)
-                chroma = chroma[:, :self.chroma_len]
-                logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
-
-        return chroma
 
-    def tokenize(self, x: WavCondition) -> WavCondition:
-        """Apply WavConditioner tokenization and populate cache if needed."""
-        x = super().tokenize(x)
-        no_undefined_paths = all(p is not None for p in x.path)
-        if self.cache is not None and no_undefined_paths:
-            paths = [Path(p) for p in x.path if p is not None]
-            self.cache.populate_embed_cache(paths, x)
-        return x
 
 
 class JointEmbeddingConditioner(BaseConditioner):
```
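The deleted docstring records the design rationale: drums and bass dominate a chromagram, so the conditioner ran Demucs first and kept only the 'vocals' and 'other' stems before extracting chroma. A condensed, self-contained sketch of that stem-filtering pipeline, using only the Demucs calls that appear in the deleted code (the input tensor is hypothetical, and `ChromaExtractor` itself is gone after this commit):

```python
import torch
from demucs import pretrained
from demucs.apply import apply_model
from demucs.audio import convert_audio

demucs = pretrained.get_model('htdemucs')   # same pretrained model the conditioner loaded
keep = torch.LongTensor([demucs.sources.index('vocals'),
                         demucs.sources.index('other')])

sr = 44100
wav = torch.randn(1, 2, sr * 5)             # hypothetical 5 s stereo clip
wav = convert_audio(wav, sr, demucs.samplerate, demucs.audio_channels)
stems = apply_model(demucs, wav)            # (batch, n_sources, channels, time)
melody = stems[:, keep].sum(dim=1)          # drop drums/bass, merge what remains
melody = convert_audio(melody, demucs.samplerate, sr, 1)  # mono at the target rate
# The conditioner then fed `melody` through ChromaExtractor to obtain the
# (batch, frames, n_chroma) conditioning tensor.
```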
 
  class JointEmbeddingConditioner(BaseConditioner):