wuxulong19950206
First model version
14d1720
raw
history blame
3.79 kB
import os
import glob
import torch
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchaudio
from utils.utils import read_wav_np
def create_dataloader(hp, args, train):
dataset = MelFromDisk(hp, args, train)
if train:
return DataLoader(dataset=dataset, batch_size=hp.train.batch_size, shuffle=True,
num_workers=hp.train.num_workers, pin_memory=True, drop_last=True)
else:
return DataLoader(dataset=dataset, batch_size=1, shuffle=False,
num_workers=0, pin_memory=False, drop_last=False)
class MelFromDisk(Dataset):
def __init__(self, hp, args, train):
self.hp = hp
self.args = args
self.train = train
self.path = hp.data.train if train else hp.data.validation
self.wav_list = glob.glob(os.path.join(self.path, '**', '*.wav'), recursive=True)
#print("Wavs path :", self.path)
#print(self.hp.data.mel_path)
#print("Length of wavelist :", len(self.wav_list))
self.mel_segment_length = hp.audio.segment_length // hp.audio.hop_length + 2
self.mapping = [i for i in range(len(self.wav_list))]
def __len__(self):
return len(self.wav_list)
def __getitem__(self, idx):
if self.train:
idx1 = idx
idx2 = self.mapping[idx1]
return self.my_getitem(idx1), self.my_getitem(idx2)
else:
return self.my_getitem(idx)
def shuffle_mapping(self):
random.shuffle(self.mapping)
def my_getitem(self, idx):
wavpath = self.wav_list[idx]
id = os.path.basename(wavpath).split(".")[0]
mel_path = "{}/{}.npy".format(self.hp.data.mel_path, id)
sr, audio = read_wav_np(wavpath)
if len(audio) < self.hp.audio.segment_length + self.hp.audio.pad_short:
audio = np.pad(audio, (0, self.hp.audio.segment_length + self.hp.audio.pad_short - len(audio)), \
mode='constant', constant_values=0.0)
audio = torch.from_numpy(audio).unsqueeze(0)
# mel = torch.load(melpath).squeeze(0) # # [num_mel, T]
mel = torch.from_numpy(np.load(mel_path))
if self.train:
max_mel_start = mel.size(1) - self.mel_segment_length
mel_start = random.randint(0, max_mel_start)
mel_end = mel_start + self.mel_segment_length
mel = mel[:, mel_start:mel_end]
audio_start = mel_start * self.hp.audio.hop_length
audio = audio[:, audio_start:audio_start+self.hp.audio.segment_length]
audio = audio + (1/32768) * torch.randn_like(audio)
return mel, audio
def collate_fn(batch):
sr = 22050
# perform padding and conversion to tensor
mels_g = [x[0][0] for x in batch]
audio_g = [x[0][1] for x in batch]
mels_g = torch.stack(mels_g)
audio_g = torch.stack(audio_g)
sub_orig_1 = torchaudio.transforms.Resample(sr, (sr // 2))(audio_g)
sub_orig_2 = torchaudio.transforms.Resample(sr, (sr // 4))(audio_g)
sub_orig_3 = torchaudio.transforms.Resample(sr, (sr // 8))(audio_g)
sub_orig_4 = torchaudio.transforms.Resample(sr, (sr // 16))(audio_g)
mels_d = [x[1][0] for x in batch]
audio_d = [x[1][1] for x in batch]
mels_d = torch.stack(mels_d)
audio_d = torch.stack(audio_d)
sub_orig_1_d = torchaudio.transforms.Resample(sr, (sr // 2))(audio_d)
sub_orig_2_d = torchaudio.transforms.Resample(sr, (sr // 4))(audio_d)
sub_orig_3_d = torchaudio.transforms.Resample(sr, (sr // 8))(audio_d)
sub_orig_4_d = torchaudio.transforms.Resample(sr, (sr // 16))(audio_d)
return [mels_g, audio_g, sub_orig_1, sub_orig_2, sub_orig_3, sub_orig_4],\
[mels_d, audio_d, sub_orig_1_d, sub_orig_2_d, sub_orig_3_d, sub_orig_4_d]