|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial
|
|
from itertools import product
|
|
import json
|
|
import math
|
|
import os
|
|
import random
|
|
import typing as tp
|
|
|
|
import pytest
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from audiocraft.data.audio_dataset import (
|
|
AudioDataset,
|
|
AudioMeta,
|
|
_get_audio_meta,
|
|
load_audio_meta,
|
|
save_audio_meta
|
|
)
|
|
from audiocraft.data.zip import PathInZip
|
|
|
|
from ..common_utils import TempDirMixin, get_white_noise, save_wav
|
|
|
|
|
|
class TestAudioMeta(TempDirMixin):
    """Tests for extracting audio metadata and for the JSONL save/load helpers."""

    def test_get_audio_meta(self):
        """Metadata read back from a freshly written wav matches what was written.

        Runs over every (sample rate, channel count) combination and checks the
        fields returned by ``_get_audio_meta(..., minimal=True)``.
        """
        duration = 1.
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            wav = get_white_noise(ch, int(duration * sample_rate))
            path = self.get_temp_path('sample.wav')
            save_wav(path, wav, sample_rate)
            m = _get_audio_meta(path, minimal=True)
            assert m.path == path, 'path does not match'
            assert m.sample_rate == sample_rate, 'sample rate does not match'
            assert m.duration == duration, 'duration does not match'
            # With minimal=True these two fields are expected to stay unset.
            assert m.amplitude is None
            assert m.info_path is None

    def test_save_audio_meta(self):
        """Entries written by ``save_audio_meta`` parse back, line by line, to equal objects."""
        populated = [
            AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
            AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
        ]
        # Both a populated list and an empty list are exercised.
        for idx, meta in enumerate([populated, []]):
            path = self.get_temp_path(f'data_{idx}_save.jsonl')
            save_audio_meta(path, meta)
            # One JSON document per line; rebuild an AudioMeta from each.
            with open(path, 'r') as f:
                read_meta = [AudioMeta.from_dict(json.loads(line)) for line in f]
            assert len(read_meta) == len(meta)
            for m, read_m in zip(meta, read_meta):
                assert m == read_m

    def test_load_audio_meta(self):
        """Entries written by hand as JSON lines are recovered by ``load_audio_meta``."""
        # dora is optional: when importable, the expected path is rewritten
        # through it before the comparison below.
        try:
            import dora
        except ImportError:
            dora = None

        populated = [
            AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
            AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
        ]
        # Both a populated list and an empty list are exercised.
        for idx, meta in enumerate([populated, []]):
            path = self.get_temp_path(f'data_{idx}_load.jsonl')
            with open(path, 'w') as f:
                f.writelines(json.dumps(m.to_dict()) + '\n' for m in meta)
            read_meta = load_audio_meta(path)
            assert len(read_meta) == len(meta)
            for m, read_m in zip(meta, read_meta):
                if dora:
                    m.path = dora.git_save.to_absolute_path(m.path)
                assert m == read_m, f'original={m}, read={read_m}'
|
|
|
|
|
|
class TestAudioDataset(TempDirMixin):
    """Tests for ``AudioDataset``: loading from files, segmenting, info, collation and sampling."""

    def _create_audio_files(self,
                            root_name: str,
                            num_examples: int,
                            durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
                            sample_rate: int = 16_000,
                            channels: int = 1):
        """Write ``num_examples`` wav files into a temporary directory.

        Args:
            root_name: Name handed to ``get_temp_dir`` for the target directory.
            num_examples: Number of files to write (``example_0.wav``, ``example_1.wav``, ...).
            durations: Either a float (used for every file), a 1-tuple (its only
                element is used for every file) or a 2-tuple of bounds from which
                a duration is drawn uniformly for each file.
            sample_rate: Sample rate the files are saved with; also used to turn
                the duration into a frame count.
            channels: Number of channels of the generated audio.

        Returns:
            The directory that contains the written files.

        Raises:
            AssertionError: If ``durations`` is none of the supported forms.
        """
        root_dir = self.get_temp_dir(root_name)
        for i in range(num_examples):
            if isinstance(durations, float):
                duration = durations
            elif isinstance(durations, tuple) and len(durations) == 1:
                duration = durations[0]
            elif isinstance(durations, tuple) and len(durations) == 2:
                duration = random.uniform(durations[0], durations[1])
            else:
                # Explicit raise (rather than `assert False`) so the check is not
                # stripped when Python runs with -O.
                raise AssertionError(f'Unsupported durations specification: {durations!r}')
            n_frames = int(duration * sample_rate)
            wav = get_white_noise(channels, n_frames)
            path = os.path.join(root_dir, f'example_{i}.wav')
            save_wav(path, wav, sample_rate)
        return root_dir

    def _create_audio_dataset(self,
                              root_name: str,
                              total_num_examples: int,
                              durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
                              sample_rate: int = 16_000,
                              channels: int = 1,
                              segment_duration: tp.Optional[float] = None,
                              num_examples: int = 10,
                              shuffle: bool = True,
                              return_info: bool = False):
        """Create audio files on disk and build an ``AudioDataset`` from their directory.

        Args:
            root_name: Temporary directory name, forwarded to ``_create_audio_files``.
            total_num_examples: Number of audio files to write.
            durations: Duration specification, forwarded to ``_create_audio_files``.
            sample_rate: Sample rate used for the files and passed to the dataset.
            channels: Channel count used for the files and passed to the dataset.
            segment_duration: Passed to the dataset as ``segment_duration``.
            num_examples: Passed to the dataset as ``num_samples``.
            shuffle: Passed to the dataset as ``shuffle``.
            return_info: Passed to the dataset as ``return_info``.

        Returns:
            The dataset returned by ``AudioDataset.from_path``.
        """
        root_dir = self._create_audio_files(root_name, total_num_examples, durations, sample_rate, channels)
        dataset = AudioDataset.from_path(root_dir,
                                         minimal_meta=True,
                                         segment_duration=segment_duration,
                                         num_samples=num_examples,
                                         sample_rate=sample_rate,
                                         channels=channels,
                                         shuffle=shuffle,
                                         return_info=return_info)
        return dataset

    def test_dataset_full(self):
        """Without a segment duration, each item's length stays within the generated duration bounds."""
        total_examples = 10
        min_duration, max_duration = 1., 4.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration),
            sample_rate=sample_rate, channels=channels, segment_duration=None)
        assert len(dataset) == total_examples
        assert dataset.sample_rate == sample_rate
        assert dataset.channels == channels
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample.shape[0] == channels
            assert sample.shape[1] <= int(max_duration * sample_rate)
            assert sample.shape[1] >= int(min_duration * sample_rate)

    def test_dataset_segment(self):
        """With a segment duration, the dataset has ``num_samples`` items of exactly that length."""
        total_examples = 10
        num_samples = 20
        min_duration, max_duration = 1., 4.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples)
        assert len(dataset) == num_samples
        assert dataset.sample_rate == sample_rate
        assert dataset.channels == channels
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample.shape[0] == channels
            assert sample.shape[1] == int(segment_duration * sample_rate)

    def test_dataset_equal_audio_and_segment_durations(self):
        """A single file whose duration equals the segment duration still yields differing items."""
        total_examples = 1
        num_samples = 2
        audio_duration = 1.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples)
        assert len(dataset) == num_samples
        assert dataset.sample_rate == sample_rate
        assert dataset.channels == channels
        for idx in range(len(dataset)):
            sample = dataset[idx]
            assert sample.shape[0] == channels
            assert sample.shape[1] == int(segment_duration * sample_rate)

        # Only one file exists, yet the two indices are expected to differ.
        sample_1 = dataset[0]
        sample_2 = dataset[1]
        assert not torch.allclose(sample_1, sample_2)

    def test_dataset_samples(self):
        """Reading the same index twice differs with ``shuffle=True`` and matches with ``shuffle=False``."""
        total_examples = 1
        num_samples = 2
        audio_duration = 1.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1

        create_dataset = partial(
            self._create_audio_dataset,
            'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples,
        )

        dataset = create_dataset(shuffle=True)

        sample_1 = dataset[0]
        sample_2 = dataset[0]
        assert not torch.allclose(sample_1, sample_2)

        dataset_noshuffle = create_dataset(shuffle=False)

        sample_1 = dataset_noshuffle[0]
        sample_2 = dataset_noshuffle[0]
        assert torch.allclose(sample_1, sample_2)

    def test_dataset_return_info(self):
        """With ``return_info=True`` each item is a ``(sample, segment_info)`` pair with consistent fields."""
        total_examples = 10
        num_samples = 20
        min_duration, max_duration = 1., 4.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
        assert len(dataset) == num_samples
        assert dataset.sample_rate == sample_rate
        assert dataset.channels == channels
        for idx in range(len(dataset)):
            sample, segment_info = dataset[idx]
            assert sample.shape[0] == channels
            assert sample.shape[1] == int(segment_duration * sample_rate)
            assert segment_info.sample_rate == sample_rate
            assert segment_info.total_frames == int(segment_duration * sample_rate)
            assert segment_info.n_frames <= int(segment_duration * sample_rate)
            assert segment_info.seek_time >= 0

    def test_dataset_return_info_no_segment_duration(self):
        """Without a segment duration, the info's ``total_frames`` matches the returned sample length."""
        total_examples = 10
        num_samples = 20
        min_duration, max_duration = 1., 4.
        segment_duration = None
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
        # The length is expected to follow the number of files, not num_samples.
        assert len(dataset) == total_examples
        assert dataset.sample_rate == sample_rate
        assert dataset.channels == channels
        for idx in range(len(dataset)):
            sample, segment_info = dataset[idx]
            assert sample.shape[0] == channels
            assert sample.shape[1] == segment_info.total_frames
            assert segment_info.sample_rate == sample_rate
            assert segment_info.n_frames <= segment_info.total_frames

    def test_dataset_collate_fn(self):
        """A ``DataLoader`` with its default collation yields batches of ``batch_size`` items."""
        total_examples = 10
        num_samples = 20
        min_duration, max_duration = 1., 4.
        segment_duration = 1.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=False)
        batch_size = 4
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=0
        )
        for batch in dataloader:
            assert batch.shape[0] == batch_size

    @pytest.mark.parametrize("segment_duration", [1.0, None])
    def test_dataset_with_meta_collate_fn(self, segment_duration):
        """``dataset.collater`` yields ``(wav, infos)`` batches of the expected sizes.

        The parametrized ``segment_duration`` is passed straight to the dataset so
        that both the segmented (``1.0``) and the unsegmented (``None``) cases run.
        """
        total_examples = 10
        num_samples = 20
        min_duration, max_duration = 1., 4.
        sample_rate = 16_000
        channels = 1
        dataset = self._create_audio_dataset(
            'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
            channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
        batch_size = 4
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collater,
            num_workers=0
        )
        # The loader keeps the last, possibly shorter, batch, so the expected size
        # is derived from how many items are still left to be served.
        remaining = len(dataset)
        for wav, infos in dataloader:
            expected = min(batch_size, remaining)
            assert wav.shape[0] == expected
            assert len(infos) == expected
            remaining -= expected
        assert remaining == 0

    @pytest.mark.parametrize("segment_duration,sample_on_weight,sample_on_duration,a_hist,b_hist,c_hist", [
        [1, True, True, 0.5, 0.5, 0.0],
        [1, False, True, 0.25, 0.5, 0.25],
        [1, True, False, 0.666, 0.333, 0.0],
        [1, False, False, 0.333, 0.333, 0.333],
        [None, False, False, 0.333, 0.333, 0.333]])
    def test_sample_with_weight(self, segment_duration, sample_on_weight, sample_on_duration, a_hist, b_hist, c_hist):
        """Empirical frequencies of ``sample_file`` match the expected proportions within 0.01."""
        # Fixed seeds keep the empirical histogram reproducible.
        random.seed(1234)
        rng = torch.Generator()
        rng.manual_seed(1234)

        def _get_histogram(dataset, repetitions=20_000):
            """Return, per file path in ``meta``, the fraction of draws that picked it."""
            counts = {file_meta.path: 0. for file_meta in meta}
            for _ in range(repetitions):
                file_meta = dataset.sample_file(0, rng)
                counts[file_meta.path] += 1
            return {name: count / repetitions for name, count in counts.items()}

        meta = [
            AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
            AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
            AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
        ]
        dataset = AudioDataset(
            meta, segment_duration=segment_duration, sample_on_weight=sample_on_weight,
            sample_on_duration=sample_on_duration)
        hist = _get_histogram(dataset)
        assert math.isclose(hist['a'], a_hist, abs_tol=0.01)
        assert math.isclose(hist['b'], b_hist, abs_tol=0.01)
        assert math.isclose(hist['c'], c_hist, abs_tol=0.01)

    def test_meta_duration_filter_all(self):
        """Building a dataset whose segment duration exceeds every file's duration must fail."""
        meta = [
            AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
            AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
            AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
        ]
        # pytest.raises makes the test fail when no AssertionError is raised; a
        # hand-written try/except AssertionError would swallow its own failure marker.
        with pytest.raises(AssertionError):
            AudioDataset(meta, segment_duration=11, min_segment_ratio=1)

    def test_meta_duration_filter_long(self):
        """With ``max_audio_duration=7``, two of the three entries are expected to remain."""
        meta = [
            AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
            AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
            AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
        ]
        dataset = AudioDataset(meta, segment_duration=None, min_segment_ratio=1, max_audio_duration=7)
        assert len(dataset) == 2
|
|
|