|
|
|
"""*********************************************************************************************""" |
|
|
|
|
|
|
|
|
|
"""*********************************************************************************************""" |
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import random |
|
|
|
import pandas as pd |
|
|
|
import torch |
|
from torch.nn.utils.rnn import pad_sequence |
|
from torch.utils.data.dataset import Dataset |
|
|
|
import torchaudio |
|
|
|
|
|
# A full bucket is split into two halves when the longest utterance in it
# exceeds this value (same units as the bucket CSV's ``length`` column).
HALF_BATCHSIZE_TIME = 2000


class PhoneDataset(Dataset):
    """Bucketed LibriSpeech (train-clean-100) utterances with integer phone labels.

    Every item of this dataset is a whole pre-built bucket of utterances, so a
    ``DataLoader`` over it is expected to use ``batch_size=1`` together with
    :meth:`collate_fn`.

    Args:
        split: ``'train'``, ``'dev'`` or ``'test'``. ``train``/``dev`` are a
            seeded 90%/10% split of ``train_split.txt``; ``test`` uses
            ``test_split.txt``.
        bucket_size: number of utterances per bucket.
        libri_root: directory the ``file_path`` entries of the bucket CSV are
            joined onto when loading audio.
        phone_path: directory holding ``converted_aligned_phones.txt``,
            ``train_split.txt`` and ``test_split.txt``.
        bucket_file: directory containing ``train-clean-100.csv`` with
            ``file_path`` and ``length`` columns.
        sample_rate: sample rate the audio files must have.
        train_dev_seed: seed of the train/dev shuffle.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``split`` is not one of the accepted values.
        FileNotFoundError: if ``bucket_file`` is not a directory.
    """

    def __init__(self, split, bucket_size, libri_root, phone_path, bucket_file, sample_rate=16000, train_dev_seed=1337, **kwargs):
        super().__init__()

        self.libri_root = libri_root
        self.phone_path = phone_path
        self.sample_rate = sample_rate
        self.class_num = 41  # number of phone classes (hard-coded)

        # utterance id -> list of integer phone labels
        self.Y = self._read_phone_labels(os.path.join(phone_path, 'converted_aligned_phones.txt'))

        usage_ids = self._read_usage_ids(split, phone_path, train_dev_seed)
        print(f'[Dataset] - # phone classes: {self.class_num}, number of data for {split}: {len(usage_ids)}')

        if not os.path.isdir(bucket_file):
            raise FileNotFoundError('Please first run `preprocess/generate_len_for_bucket.py` to get bucket file.')
        # Longest utterances first, so each bucket groups similar lengths.
        table = pd.read_csv(os.path.join(bucket_file, 'train-clean-100.csv')).sort_values(by=['length'], ascending=False)

        # List of buckets; each bucket is a list of audio paths relative to libri_root.
        self.X = self._make_buckets(table['file_path'].tolist(), table['length'].tolist(), usage_ids, bucket_size)

    @staticmethod
    def _read_phone_labels(label_file):
        """Parse ``<id> <label> <label> ...`` lines into ``{id: [int, ...]}``."""
        labels = {}
        with open(label_file, encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:  # tolerate blank lines
                    continue
                labels[fields[0]] = [int(p) for p in fields[1:]]
        return labels

    @staticmethod
    def _read_usage_ids(split, phone_path, train_dev_seed):
        """Return the set of utterance ids that belong to ``split``."""
        if split in ('train', 'dev'):
            with open(os.path.join(phone_path, 'train_split.txt'), encoding='utf-8') as f:
                lines = f.readlines()
            # A private RNG yields the same permutation as random.seed() +
            # random.shuffle() but leaves the caller's global RNG untouched.
            random.Random(train_dev_seed).shuffle(lines)
            cut = int(len(lines) * 0.9)
            lines = lines[:cut] if split == 'train' else lines[cut:]
        elif split == 'test':
            with open(os.path.join(phone_path, 'test_split.txt'), encoding='utf-8') as f:
                lines = f.readlines()
        else:
            raise ValueError('Invalid \'split\' argument for dataset: PhoneDataset!')
        # Normalise only after shuffling/cutting the raw lines so the split is
        # exactly the one produced before; blank lines are then discarded.
        return {line.strip() for line in lines if line.strip()}

    def _make_buckets(self, paths, lengths, usage_ids, bucket_size):
        """Group the in-split ``paths`` (already length-sorted) into buckets."""
        buckets = []
        batch_x, batch_len = [], []
        for x, x_len in zip(paths, lengths):
            if self._parse_x_name(x) not in usage_ids:
                continue
            batch_x.append(x)
            batch_len.append(x_len)

            if len(batch_x) == bucket_size:
                # Halve buckets that contain long utterances.
                if bucket_size >= 2 and max(batch_len) > HALF_BATCHSIZE_TIME:
                    buckets.append(batch_x[:bucket_size // 2])
                    buckets.append(batch_x[bucket_size // 2:])
                else:
                    buckets.append(batch_x)
                batch_x, batch_len = [], []

        # Keep the final partial bucket. Every entry already passed the split
        # filter, so no further membership test is needed. A single leftover
        # utterance is still dropped, as before.
        if len(batch_x) > 1:
            buckets.append(batch_x)
        return buckets

    def _parse_x_name(self, x):
        """Return the utterance id of path ``x``: its basename up to the first '.'."""
        return x.split('/')[-1].split('.')[0]

    def _load_wav(self, wav_path):
        """Load ``libri_root/wav_path`` and return it flattened with ``view(-1)``.

        Raises:
            ValueError: if the file's sample rate differs from ``self.sample_rate``.
        """
        wav, sr = torchaudio.load(os.path.join(self.libri_root, wav_path))
        if sr != self.sample_rate:
            raise ValueError(f'Sample rate mismatch for {wav_path}: file has {sr}, dataset expects {self.sample_rate}')
        return wav.view(-1)

    def __len__(self):
        """Number of buckets (not utterances)."""
        return len(self.X)

    def __getitem__(self, index):
        """Return ``(wavs, labels)`` for bucket ``index``.

        ``wavs`` is a list of flattened waveform tensors and ``labels`` a list
        of ``torch.LongTensor`` phone-label sequences, in the same order.
        """
        files = self.X[index]
        wav_batch = [self._load_wav(f) for f in files]
        label_batch = [torch.LongTensor(self.Y[self._parse_x_name(f)]) for f in files]
        return wav_batch, label_batch

    def collate_fn(self, items):
        """Unwrap a DataLoader batch holding one bucket.

        Only ``items[0]`` is returned, so the DataLoader must use
        ``batch_size=1``; any further items would be silently discarded.
        """
        return items[0][0], items[0][1]
|
|