File size: 4,581 Bytes
20d6bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
import librosa
import numpy as np
import random
import os
from torch.utils.data import DataLoader
from modules.audio import mel_spectrogram


duration_setting = {
    "min": 1.0,
    "max": 30.0,
}
# assume single speaker
class FT_Dataset(torch.utils.data.Dataset):
    def __init__(self,

                 data_path,

                 spect_params,

                 sr=22050,

                 batch_size=1,

                 ):
        self.data_path = data_path
        # recursively find all files in data_path
        self.data = []
        for root, _, files in os.walk(data_path):
            for file in files:
                if (file.endswith(".wav") or
                        file.endswith(".mp3") or
                        file.endswith(".flac") or
                        file.endswith(".ogg") or
                        file.endswith(".m4a") or
                        file.endswith(".opus")):
                    self.data.append(os.path.join(root, file))

        mel_fn_args = {
            "n_fft": spect_params['n_fft'],
            "win_size": spect_params['win_length'],
            "hop_size": spect_params['hop_length'],
            "num_mels": spect_params['n_mels'],
            "sampling_rate": sr,
            "fmin": spect_params['fmin'],
            "fmax": None if spect_params['fmax'] == "None" else spect_params['fmax'],
            "center": False
        }
        self.to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
        self.sr = sr

        assert len(self.data) != 0
        # if dataset length is less than batch size, repeat the dataset
        while len(self.data) < batch_size:
            self.data += self.data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        idx = idx % len(self.data)
        wav_path = self.data[idx]
        try:
            speech, orig_sr = librosa.load(wav_path, sr=self.sr)
        except Exception as e:
            print(f"Failed to load wav file with error {e}")
            return self.__getitem__(random.randint(0, len(self)))
        if len(speech) < self.sr * duration_setting["min"] or len(speech) > self.sr * duration_setting["max"]:
            print(f"Audio {wav_path} is too short or too long, skipping")
            return self.__getitem__(random.randint(0, len(self)))
        return_dict = {
            'audio': speech,
            'sr': orig_sr
        }
        wave, orig_sr = return_dict['audio'], return_dict['sr']
        if orig_sr != self.sr:
            wave = librosa.resample(wave, orig_sr, self.sr)
        wave = torch.from_numpy(wave).float()
        mel = self.to_mel(wave.unsqueeze(0)).squeeze(0)

        return wave, mel

def build_ft_dataloader(data_path, spect_params, sr, batch_size=1, num_workers=0):
    dataset = FT_Dataset(data_path, spect_params, sr, batch_size)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate,
    )
    return dataloader

def collate(batch):
    batch_size = len(batch)

    # sort by mel length
    lengths = [b[1].shape[1] for b in batch]
    batch_indexes = np.argsort(lengths)[::-1]
    batch = [batch[bid] for bid in batch_indexes]

    nmels = batch[0][1].size(0)
    max_mel_length = max([b[1].shape[1] for b in batch])
    max_wave_length = max([b[0].size(0) for b in batch])

    mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
    waves = torch.zeros((batch_size, max_wave_length)).float()

    mel_lengths = torch.zeros(batch_size).long()
    wave_lengths = torch.zeros(batch_size).long()

    for bid, (wave, mel) in enumerate(batch):
        mel_size = mel.size(1)
        mels[bid, :, :mel_size] = mel
        waves[bid, : wave.size(0)] = wave
        mel_lengths[bid] = mel_size
        wave_lengths[bid] = wave.size(0)

    return waves, mels, wave_lengths, mel_lengths

if __name__ == "__main__":
    data_path = "./example/reference"
    sr = 22050
    spect_params = {
        "n_fft": 1024,
        "win_length": 1024,
        "hop_length": 256,
        "n_mels": 80,
        "fmin": 0,
        "fmax": 8000,
    }
    dataloader = build_ft_dataloader(data_path, spect_params, sr, batch_size=2, num_workers=0)
    for idx, batch in enumerate(dataloader):
        wave, mel, wave_lengths, mel_lengths = batch
        print(wave.shape, mel.shape)
        if idx == 10:
            break