|
from pathlib import Path |
|
from typing import Callable |
|
from typing import Dict |
|
from typing import List |
|
from typing import Union |
|
|
|
import numpy as np |
|
from torch.utils.data import SequentialSampler |
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
from ..core import AudioSignal |
|
from ..core import util |
|
|
|
|
|
class AudioLoader: |
|
"""Loads audio endlessly from a list of audio sources |
|
containing paths to audio files. Audio sources can be |
|
folders full of audio files (which are found via file |
|
extension) or by providing a CSV file which contains paths |
|
to audio files. |
|
|
|
Parameters |
|
---------- |
|
sources : List[str], optional |
|
Sources containing folders, or CSVs with |
|
paths to audio files, by default None |
|
weights : List[float], optional |
|
Weights to sample audio files from each source, by default None |
|
relative_path : str, optional |
|
Path audio should be loaded relative to, by default "" |
|
transform : Callable, optional |
|
Transform to instantiate alongside audio sample, |
|
by default None |
|
ext : List[str] |
|
List of extensions to find audio within each source by. Can |
|
also be a file name (e.g. "vocals.wav"). by default |
|
``['.wav', '.flac', '.mp3', '.mp4']``. |
|
shuffle: bool |
|
Whether to shuffle the files within the dataloader. Defaults to True. |
|
shuffle_state: int |
|
State to use to seed the shuffle of the files. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
sources: List[str] = None, |
|
weights: List[float] = None, |
|
transform: Callable = None, |
|
relative_path: str = "", |
|
ext: List[str] = util.AUDIO_EXTENSIONS, |
|
shuffle: bool = True, |
|
shuffle_state: int = 0, |
|
): |
|
self.audio_lists = util.read_sources( |
|
sources, relative_path=relative_path, ext=ext |
|
) |
|
|
|
self.audio_indices = [ |
|
(src_idx, item_idx) |
|
for src_idx, src in enumerate(self.audio_lists) |
|
for item_idx in range(len(src)) |
|
] |
|
if shuffle: |
|
state = util.random_state(shuffle_state) |
|
state.shuffle(self.audio_indices) |
|
|
|
self.sources = sources |
|
self.weights = weights |
|
self.transform = transform |
|
|
|
def __call__( |
|
self, |
|
state, |
|
sample_rate: int, |
|
duration: float, |
|
loudness_cutoff: float = -40, |
|
num_channels: int = 1, |
|
offset: float = None, |
|
source_idx: int = None, |
|
item_idx: int = None, |
|
global_idx: int = None, |
|
): |
|
if source_idx is not None and item_idx is not None: |
|
try: |
|
audio_info = self.audio_lists[source_idx][item_idx] |
|
except: |
|
audio_info = {"path": "none"} |
|
elif global_idx is not None: |
|
source_idx, item_idx = self.audio_indices[ |
|
global_idx % len(self.audio_indices) |
|
] |
|
audio_info = self.audio_lists[source_idx][item_idx] |
|
else: |
|
audio_info, source_idx, item_idx = util.choose_from_list_of_lists( |
|
state, self.audio_lists, p=self.weights |
|
) |
|
|
|
path = audio_info["path"] |
|
signal = AudioSignal.zeros(duration, sample_rate, num_channels) |
|
|
|
if path != "none": |
|
if offset is None: |
|
signal = AudioSignal.salient_excerpt( |
|
path, |
|
duration=duration, |
|
state=state, |
|
loudness_cutoff=loudness_cutoff, |
|
) |
|
else: |
|
signal = AudioSignal( |
|
path, |
|
offset=offset, |
|
duration=duration, |
|
) |
|
|
|
if num_channels == 1: |
|
signal = signal.to_mono() |
|
signal = signal.resample(sample_rate) |
|
|
|
if signal.duration < duration: |
|
signal = signal.zero_pad_to(int(duration * sample_rate)) |
|
|
|
for k, v in audio_info.items(): |
|
signal.metadata[k] = v |
|
|
|
item = { |
|
"signal": signal, |
|
"source_idx": source_idx, |
|
"item_idx": item_idx, |
|
"source": str(self.sources[source_idx]), |
|
"path": str(path), |
|
} |
|
if self.transform is not None: |
|
item["transform_args"] = self.transform.instantiate(state, signal=signal) |
|
return item |
|
|
|
|
|
def default_matcher(x, y): |
|
return Path(x).parent == Path(y).parent |
|
|
|
|
|
def align_lists(lists, matcher: Callable = default_matcher): |
|
longest_list = lists[np.argmax([len(l) for l in lists])] |
|
for i, x in enumerate(longest_list): |
|
for l in lists: |
|
if i >= len(l): |
|
l.append({"path": "none"}) |
|
elif not matcher(l[i]["path"], x["path"]): |
|
l.insert(i, {"path": "none"}) |
|
return lists |
|
|
|
|
|
class AudioDataset: |
|
"""Loads audio from multiple loaders (with associated transforms) |
|
for a specified number of samples. Excerpts are drawn randomly |
|
of the specified duration, above a specified loudness threshold |
|
and are resampled on the fly to the desired sample rate |
|
(if it is different from the audio source sample rate). |
|
|
|
This takes either a single AudioLoader object, |
|
a dictionary of AudioLoader objects, or a dictionary of AudioLoader |
|
objects. Each AudioLoader is called by the dataset, and the |
|
result is placed in the output dictionary. A transform can also be |
|
specified for the entire dataset, rather than for each specific |
|
loader. This transform can be applied to the output of all the |
|
loaders if desired. |
|
|
|
AudioLoader objects can be specified as aligned, which means the |
|
loaders correspond to multitrack audio (e.g. a vocals, bass, |
|
drums, and other loader for multitrack music mixtures). |
|
|
|
|
|
Parameters |
|
---------- |
|
loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]] |
|
AudioLoaders to sample audio from. |
|
sample_rate : int |
|
Desired sample rate. |
|
n_examples : int, optional |
|
Number of examples (length of dataset), by default 1000 |
|
duration : float, optional |
|
Duration of audio samples, by default 0.5 |
|
loudness_cutoff : float, optional |
|
Loudness cutoff threshold for audio samples, by default -40 |
|
num_channels : int, optional |
|
Number of channels in output audio, by default 1 |
|
transform : Callable, optional |
|
Transform to instantiate alongside each dataset item, by default None |
|
aligned : bool, optional |
|
Whether the loaders should be sampled in an aligned manner (e.g. same |
|
offset, duration, and matched file name), by default False |
|
shuffle_loaders : bool, optional |
|
Whether to shuffle the loaders before sampling from them, by default False |
|
matcher : Callable |
|
How to match files from adjacent audio lists (e.g. for a multitrack audio loader), |
|
by default uses the parent directory of each file. |
|
without_replacement : bool |
|
Whether to choose files with or without replacement, by default True. |
|
|
|
|
|
Examples |
|
-------- |
|
>>> from audiotools.data.datasets import AudioLoader |
|
>>> from audiotools.data.datasets import AudioDataset |
|
>>> from audiotools import transforms as tfm |
|
>>> import numpy as np |
|
>>> |
|
>>> loaders = [ |
|
>>> AudioLoader( |
|
>>> sources=[f"tests/audio/spk"], |
|
>>> transform=tfm.Equalizer(), |
|
>>> ext=["wav"], |
|
>>> ) |
|
>>> for i in range(5) |
|
>>> ] |
|
>>> |
|
>>> dataset = AudioDataset( |
|
>>> loaders = loaders, |
|
>>> sample_rate = 44100, |
|
>>> duration = 1.0, |
|
>>> transform = tfm.RescaleAudio(), |
|
>>> ) |
|
>>> |
|
>>> item = dataset[np.random.randint(len(dataset))] |
|
>>> |
|
>>> for i in range(len(loaders)): |
|
>>> item[i]["signal"] = loaders[i].transform( |
|
>>> item[i]["signal"], **item[i]["transform_args"] |
|
>>> ) |
|
>>> item[i]["signal"].widget(i) |
|
>>> |
|
>>> mix = sum([item[i]["signal"] for i in range(len(loaders))]) |
|
>>> mix = dataset.transform(mix, **item["transform_args"]) |
|
>>> mix.widget("mix") |
|
|
|
Below is an example of how one could load MUSDB multitrack data: |
|
|
|
>>> import audiotools as at |
|
>>> from pathlib import Path |
|
>>> from audiotools import transforms as tfm |
|
>>> import numpy as np |
|
>>> import torch |
|
>>> |
|
>>> def build_dataset( |
|
>>> sample_rate: int = 44100, |
|
>>> duration: float = 5.0, |
|
>>> musdb_path: str = "~/.data/musdb/", |
|
>>> ): |
|
>>> musdb_path = Path(musdb_path).expanduser() |
|
>>> loaders = { |
|
>>> src: at.datasets.AudioLoader( |
|
>>> sources=[musdb_path], |
|
>>> transform=tfm.Compose( |
|
>>> tfm.VolumeNorm(("uniform", -20, -10)), |
|
>>> tfm.Silence(prob=0.1), |
|
>>> ), |
|
>>> ext=[f"{src}.wav"], |
|
>>> ) |
|
>>> for src in ["vocals", "bass", "drums", "other"] |
|
>>> } |
|
>>> |
|
>>> dataset = at.datasets.AudioDataset( |
|
>>> loaders=loaders, |
|
>>> sample_rate=sample_rate, |
|
>>> duration=duration, |
|
>>> num_channels=1, |
|
>>> aligned=True, |
|
>>> transform=tfm.RescaleAudio(), |
|
>>> shuffle_loaders=True, |
|
>>> ) |
|
>>> return dataset, list(loaders.keys()) |
|
>>> |
|
>>> train_data, sources = build_dataset() |
|
>>> dataloader = torch.utils.data.DataLoader( |
|
>>> train_data, |
|
>>> batch_size=16, |
|
>>> num_workers=0, |
|
>>> collate_fn=train_data.collate, |
|
>>> ) |
|
>>> batch = next(iter(dataloader)) |
|
>>> |
|
>>> for k in sources: |
|
>>> src = batch[k] |
|
>>> src["transformed"] = train_data.loaders[k].transform( |
|
>>> src["signal"].clone(), **src["transform_args"] |
|
>>> ) |
|
>>> |
|
>>> mixture = sum(batch[k]["transformed"] for k in sources) |
|
>>> mixture = train_data.transform(mixture, **batch["transform_args"]) |
|
>>> |
|
>>> # Say a model takes the mix and gives back (n_batch, n_src, n_time). |
|
>>> # Construct the targets: |
|
>>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1) |
|
|
|
Similarly, here's example code for loading Slakh data: |
|
|
|
>>> import audiotools as at |
|
>>> from pathlib import Path |
|
>>> from audiotools import transforms as tfm |
|
>>> import numpy as np |
|
>>> import torch |
|
>>> import glob |
|
>>> |
|
>>> def build_dataset( |
|
>>> sample_rate: int = 16000, |
|
>>> duration: float = 10.0, |
|
>>> slakh_path: str = "~/.data/slakh/", |
|
>>> ): |
|
>>> slakh_path = Path(slakh_path).expanduser() |
|
>>> |
|
>>> # Find the max number of sources in Slakh |
|
>>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)] |
|
>>> n_sources = len(list(set(src_names))) |
|
>>> |
|
>>> loaders = { |
|
>>> f"S{i:02d}": at.datasets.AudioLoader( |
|
>>> sources=[slakh_path], |
|
>>> transform=tfm.Compose( |
|
>>> tfm.VolumeNorm(("uniform", -20, -10)), |
|
>>> tfm.Silence(prob=0.1), |
|
>>> ), |
|
>>> ext=[f"S{i:02d}.wav"], |
|
>>> ) |
|
>>> for i in range(n_sources) |
|
>>> } |
|
>>> dataset = at.datasets.AudioDataset( |
|
>>> loaders=loaders, |
|
>>> sample_rate=sample_rate, |
|
>>> duration=duration, |
|
>>> num_channels=1, |
|
>>> aligned=True, |
|
>>> transform=tfm.RescaleAudio(), |
|
>>> shuffle_loaders=False, |
|
>>> ) |
|
>>> |
|
>>> return dataset, list(loaders.keys()) |
|
>>> |
|
>>> train_data, sources = build_dataset() |
|
>>> dataloader = torch.utils.data.DataLoader( |
|
>>> train_data, |
|
>>> batch_size=16, |
|
>>> num_workers=0, |
|
>>> collate_fn=train_data.collate, |
|
>>> ) |
|
>>> batch = next(iter(dataloader)) |
|
>>> |
|
>>> for k in sources: |
|
>>> src = batch[k] |
|
>>> src["transformed"] = train_data.loaders[k].transform( |
|
>>> src["signal"].clone(), **src["transform_args"] |
|
>>> ) |
|
>>> |
|
>>> mixture = sum(batch[k]["transformed"] for k in sources) |
|
>>> mixture = train_data.transform(mixture, **batch["transform_args"]) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]], |
|
sample_rate: int, |
|
n_examples: int = 1000, |
|
duration: float = 0.5, |
|
offset: float = None, |
|
loudness_cutoff: float = -40, |
|
num_channels: int = 1, |
|
transform: Callable = None, |
|
aligned: bool = False, |
|
shuffle_loaders: bool = False, |
|
matcher: Callable = default_matcher, |
|
without_replacement: bool = True, |
|
): |
|
|
|
if isinstance(loaders, list): |
|
loaders = {i: l for i, l in enumerate(loaders)} |
|
elif isinstance(loaders, AudioLoader): |
|
loaders = {0: loaders} |
|
|
|
self.loaders = loaders |
|
self.loudness_cutoff = loudness_cutoff |
|
self.num_channels = num_channels |
|
|
|
self.length = n_examples |
|
self.transform = transform |
|
self.sample_rate = sample_rate |
|
self.duration = duration |
|
self.offset = offset |
|
self.aligned = aligned |
|
self.shuffle_loaders = shuffle_loaders |
|
self.without_replacement = without_replacement |
|
|
|
if aligned: |
|
loaders_list = list(loaders.values()) |
|
for i in range(len(loaders_list[0].audio_lists)): |
|
input_lists = [l.audio_lists[i] for l in loaders_list] |
|
|
|
align_lists(input_lists, matcher) |
|
|
|
def __getitem__(self, idx): |
|
state = util.random_state(idx) |
|
offset = None if self.offset is None else self.offset |
|
item = {} |
|
|
|
keys = list(self.loaders.keys()) |
|
if self.shuffle_loaders: |
|
state.shuffle(keys) |
|
|
|
loader_kwargs = { |
|
"state": state, |
|
"sample_rate": self.sample_rate, |
|
"duration": self.duration, |
|
"loudness_cutoff": self.loudness_cutoff, |
|
"num_channels": self.num_channels, |
|
"global_idx": idx if self.without_replacement else None, |
|
} |
|
|
|
|
|
loader = self.loaders[keys[0]] |
|
item[keys[0]] = loader(**loader_kwargs) |
|
|
|
for key in keys[1:]: |
|
loader = self.loaders[key] |
|
if self.aligned: |
|
|
|
|
|
offset = item[keys[0]]["signal"].metadata["offset"] |
|
loader_kwargs.update( |
|
{ |
|
"offset": offset, |
|
"source_idx": item[keys[0]]["source_idx"], |
|
"item_idx": item[keys[0]]["item_idx"], |
|
} |
|
) |
|
item[key] = loader(**loader_kwargs) |
|
|
|
|
|
keys = list(self.loaders.keys()) |
|
item = {k: item[k] for k in keys} |
|
|
|
item["idx"] = idx |
|
if self.transform is not None: |
|
item["transform_args"] = self.transform.instantiate( |
|
state=state, signal=item[keys[0]]["signal"] |
|
) |
|
|
|
|
|
|
|
|
|
if len(keys) == 1: |
|
item.update(item.pop(keys[0])) |
|
|
|
return item |
|
|
|
def __len__(self): |
|
return self.length |
|
|
|
@staticmethod |
|
def collate(list_of_dicts: Union[list, dict], n_splits: int = None): |
|
"""Collates items drawn from this dataset. Uses |
|
:py:func:`audiotools.core.util.collate`. |
|
|
|
Parameters |
|
---------- |
|
list_of_dicts : typing.Union[list, dict] |
|
Data drawn from each item. |
|
n_splits : int |
|
Number of splits to make when creating the batches (split into |
|
sub-batches). Useful for things like gradient accumulation. |
|
|
|
Returns |
|
------- |
|
dict |
|
Dictionary of batched data. |
|
""" |
|
return util.collate(list_of_dicts, n_splits=n_splits) |
|
|
|
|
|
class ConcatDataset(AudioDataset): |
|
def __init__(self, datasets: list): |
|
self.datasets = datasets |
|
|
|
def __len__(self): |
|
return sum([len(d) for d in self.datasets]) |
|
|
|
def __getitem__(self, idx): |
|
dataset = self.datasets[idx % len(self.datasets)] |
|
return dataset[idx // len(self.datasets)] |
|
|
|
|
|
class ResumableDistributedSampler(DistributedSampler): |
|
"""Distributed sampler that can be resumed from a given start index.""" |
|
|
|
def __init__(self, dataset, start_idx: int = None, **kwargs): |
|
super().__init__(dataset, **kwargs) |
|
|
|
self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0 |
|
|
|
def __iter__(self): |
|
for i, idx in enumerate(super().__iter__()): |
|
if i >= self.start_idx: |
|
yield idx |
|
self.start_idx = 0 |
|
|
|
|
|
class ResumableSequentialSampler(SequentialSampler): |
|
"""Sequential sampler that can be resumed from a given start index.""" |
|
|
|
def __init__(self, dataset, start_idx: int = None, **kwargs): |
|
super().__init__(dataset, **kwargs) |
|
|
|
self.start_idx = start_idx if start_idx is not None else 0 |
|
|
|
def __iter__(self): |
|
for i, idx in enumerate(super().__iter__()): |
|
if i >= self.start_idx: |
|
yield idx |
|
self.start_idx = 0 |
|
|