import os
import random

import librosa
import numpy as np
import torch
from torch.utils.data import DataLoader

from modules.audio import mel_spectrogram

# Allowed clip duration range, in seconds.
duration_setting = {
    "min": 1.0,
    "max": 30.0,
}
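

# FT_Dataset recursively scans `data_path` for audio files and yields
# (waveform, mel-spectrogram) pairs for fine-tuning.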
class FT_Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_path,
        spect_params,
        sr=22050,
        batch_size=1,
    ):
        self.data_path = data_path

        # Collect every supported audio file under data_path, recursively.
        audio_extensions = (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".opus")
        self.data = []
        for root, _, files in os.walk(data_path):
            for file in files:
                if file.endswith(audio_extensions):
                    self.data.append(os.path.join(root, file))

        mel_fn_args = {
            "n_fft": spect_params["n_fft"],
            "win_size": spect_params["win_length"],
            "hop_size": spect_params["hop_length"],
            "num_mels": spect_params["n_mels"],
            "sampling_rate": sr,
            "fmin": spect_params["fmin"],
            "fmax": None if spect_params["fmax"] == "None" else spect_params["fmax"],
            "center": False,
        }
        self.to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
        self.sr = sr

        assert len(self.data) != 0, f"No audio files found under {data_path}"

        # Repeat the file list until it can fill at least one batch.
        while len(self.data) < batch_size:
            self.data += self.data

    def __len__(self):
        return len(self.data)
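
    # Failed or out-of-range samples are skipped by retrying with a random
    # index, so a bad file never stops iteration.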
    def __getitem__(self, idx):
        idx = idx % len(self.data)
        wav_path = self.data[idx]
        try:
            # librosa.load already resamples to self.sr.
            speech, orig_sr = librosa.load(wav_path, sr=self.sr)
        except Exception as e:
            print(f"Failed to load wav file {wav_path} with error {e}")
            return self.__getitem__(random.randint(0, len(self) - 1))
        if len(speech) < self.sr * duration_setting["min"] or len(speech) > self.sr * duration_setting["max"]:
            print(f"Audio {wav_path} is too short or too long, skipping")
            return self.__getitem__(random.randint(0, len(self) - 1))

        wave = speech
        if orig_sr != self.sr:
            # Safeguard only; orig_sr should already equal self.sr after librosa.load.
            wave = librosa.resample(wave, orig_sr=orig_sr, target_sr=self.sr)
        wave = torch.from_numpy(wave).float()
        mel = self.to_mel(wave.unsqueeze(0)).squeeze(0)

        return wave, mel
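

# Builds a shuffled DataLoader over FT_Dataset using the length-sorting
# `collate` function defined below.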
def build_ft_dataloader(data_path, spect_params, sr, batch_size=1, num_workers=0):
    dataset = FT_Dataset(data_path, spect_params, sr, batch_size)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate,
    )
    return dataloader
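

# Pads a batch of variable-length (wave, mel) pairs to a common length.
# Mel frames are padded with -10, which is roughly the silence floor of a
# log-mel spectrogram, while waveforms are zero-padded.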
def collate(batch):
    batch_size = len(batch)

    # Sort batch items by mel length, longest first.
    lengths = [b[1].shape[1] for b in batch]
    batch_indexes = np.argsort(lengths)[::-1]
    batch = [batch[bid] for bid in batch_indexes]

    nmels = batch[0][1].size(0)
    max_mel_length = max([b[1].shape[1] for b in batch])
    max_wave_length = max([b[0].size(0) for b in batch])

    mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
    waves = torch.zeros((batch_size, max_wave_length)).float()

    mel_lengths = torch.zeros(batch_size).long()
    wave_lengths = torch.zeros(batch_size).long()

    # Copy each item into the padded tensors and record its true length.
    for bid, (wave, mel) in enumerate(batch):
        mel_size = mel.size(1)
        mels[bid, :, :mel_size] = mel
        waves[bid, : wave.size(0)] = wave
        mel_lengths[bid] = mel_size
        wave_lengths[bid] = wave.size(0)

    return waves, mels, wave_lengths, mel_lengths
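

# Quick smoke test: iterate over the example reference audio and print the
# padded batch shapes for a few batches.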
if __name__ == "__main__":
    data_path = "./example/reference"
    sr = 22050
    spect_params = {
        "n_fft": 1024,
        "win_length": 1024,
        "hop_length": 256,
        "n_mels": 80,
        "fmin": 0,
        "fmax": 8000,
    }
    dataloader = build_ft_dataloader(data_path, spect_params, sr, batch_size=2, num_workers=0)
    for idx, batch in enumerate(dataloader):
        wave, mel, wave_lengths, mel_lengths = batch
        print(wave.shape, mel.shape)
        if idx == 10:
            break