# NOTE: web-page residue kept as a comment (uploader / commit info):
# lmzjms's picture | Upload 1162 files | 0b32ad6 verified
# -*- coding: utf-8 -*- #
# FileName [ dataset.py ]
# Synopsis [ the phone dataset ]
# Author [ S3PRL ]
# Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ]
import os
import random
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataset import Dataset
import torchaudio
# Phone Dataset #
class PhoneDataset(Dataset):
    """Phone-label dataset over audio files under ``libri_root``, pre-grouped into length buckets.

    Each dataset item is one whole bucket (a list of utterances), not a single
    utterance. ``collate_fn`` only looks at ``items[0]``, so it is meant to be
    used with a DataLoader whose ``batch_size`` is 1.

    Attributes:
        Y: utterance id -> list of integer phone labels, read from
            ``converted_aligned_phones.txt``.
        X: list of buckets; each bucket is a list of file paths (relative to
            ``libri_root``) taken from the bucket CSV.
    """

    # NOTE(review): HALF_BATCHSIZE_TIME was referenced but never defined in this
    # file. 2000 matches the upstream S3PRL phone dataset. Its unit is whatever
    # the bucket CSV's ``length`` column uses; confirm before tuning. A full
    # bucket whose longest file exceeds this value is split into two halves.
    # Override via a subclass if needed.
    HALF_BATCHSIZE_TIME = 2000

    def __init__(self, split, bucket_size, libri_root, phone_path, bucket_file, sample_rate=16000, train_dev_seed=1337, **kwargs):
        """Read labels and the split list, then build length buckets.

        Args:
            split: 'train', 'dev' or 'test'. 'train'/'dev' are a seeded
                90%/10% cut of ``train_split.txt``; 'test' uses ``test_split.txt``.
            bucket_size: number of files per bucket (before halving).
            libri_root: directory that CSV file paths are relative to.
            phone_path: directory holding the label and split files.
            bucket_file: directory containing ``train-clean-100.csv`` with
                ``file_path`` and ``length`` columns.
            sample_rate: stored only; it is not checked against loaded audio.
            train_dev_seed: seed for the deterministic train/dev shuffle.
            **kwargs: ignored.

        Raises:
            ValueError: if ``split`` is not one of the three values above.
            AssertionError: if ``bucket_file`` is not a directory.
        """
        super(PhoneDataset, self).__init__()
        self.libri_root = libri_root
        self.phone_path = phone_path
        self.sample_rate = sample_rate
        self.class_num = 41 # NOTE: pre-computed, should not need change

        self.Y = self._read_labels(os.path.join(phone_path, 'converted_aligned_phones.txt'))

        usage_list = self._read_usage_list(split, phone_path, train_dev_seed)
        print('[Dataset] - # phone classes: ' + str(self.class_num) + ', number of data for ' + split + ': ' + str(len(usage_list)))

        # Read table for bucketing (longest files first)
        assert os.path.isdir(bucket_file), 'Please first run `preprocess/generate_len_for_bucket.py to get bucket file.'
        table = pd.read_csv(os.path.join(bucket_file, 'train-clean-100.csv')).sort_values(by=['length'], ascending=False)
        X = table['file_path'].tolist()
        X_lens = table['length'].tolist()

        # Use bucketing to allow different batch sizes at run time
        self.X = self._make_buckets(X, X_lens, usage_list, bucket_size)

    @staticmethod
    def _read_labels(label_path):
        """Parse '<utt_id> <label> <label> ...' lines into {utt_id: [int, ...]}."""
        labels = {}
        with open(label_path) as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                fields = line.strip('\n').split(' ')
                labels[fields[0]] = [int(p) for p in fields[1:]]
        return labels

    @staticmethod
    def _read_usage_list(split, phone_path, train_dev_seed):
        """Return the set of utterance ids that belong to ``split``."""
        if split == 'train' or split == 'dev':
            with open(os.path.join(phone_path, 'train_split.txt')) as f:
                usage_list = f.readlines()
            # Seeded shuffle before the 90/10 cut, so dev is a reproducible random
            # hold-out. A private Random instance leaves the global `random` state alone.
            random.Random(train_dev_seed).shuffle(usage_list)
            percent = int(len(usage_list) * 0.9)
            usage_list = usage_list[:percent] if split == 'train' else usage_list[percent:]
        elif split == 'test':
            with open(os.path.join(phone_path, 'test_split.txt')) as f:
                usage_list = f.readlines()
        else:
            raise ValueError('Invalid \'split\' argument for dataset: PhoneDataset!')
        return {line.strip('\n') for line in usage_list}

    def _make_buckets(self, X, X_lens, usage_list, bucket_size):
        """Group the files of ``X`` whose id is in ``usage_list`` into buckets.

        Files are taken in the given order (longest first, from ``__init__``).
        A full bucket whose longest file exceeds ``HALF_BATCHSIZE_TIME`` is
        split into two halves when ``bucket_size >= 2``.
        """
        buckets = []
        batch_x, batch_len = [], []
        for x, x_len in zip(X, X_lens):
            if self._parse_x_name(x) not in usage_list:
                continue
            batch_x.append(x)
            batch_len.append(x_len)
            # Fill in batch_x until batch is full
            if len(batch_x) == bucket_size:
                # Half the batch size if seq too long
                if (bucket_size >= 2) and (max(batch_len) > self.HALF_BATCHSIZE_TIME):
                    buckets.append(batch_x[:bucket_size // 2])
                    buckets.append(batch_x[bucket_size // 2:])
                else:
                    buckets.append(batch_x)
                batch_x, batch_len = [], []
        # Gather the last batch. Every file in batch_x already passed the usage
        # filter, so no re-check of the last loop variable is needed.
        # NOTE(review): a leftover of exactly one file is still dropped, as before.
        if len(batch_x) > 1:
            buckets.append(batch_x)
        return buckets

    def _parse_x_name(self, x):
        """Return the utterance id: the basename of ``x`` up to its first '.'."""
        return x.split('/')[-1].split('.')[0]

    def _load_wav(self, wav_path):
        """Load ``wav_path`` (relative to ``libri_root``) and flatten it to 1-D."""
        wav, sr = torchaudio.load(os.path.join(self.libri_root, wav_path))
        # NOTE: sr is not checked against self.sample_rate.
        # view(-1) flattens all channels; assumes mono audio (TODO confirm).
        return wav.view(-1)

    def __len__(self):
        """Number of buckets (not utterances)."""
        return len(self.X)

    def __getitem__(self, index):
        """Load bucket ``index``: returns (list of 1-D wav tensors, list of LongTensor labels).

        Nothing is padded here; lengths within the bucket may differ.
        """
        wav_batch = [self._load_wav(x_file) for x_file in self.X[index]]
        label_batch = [torch.LongTensor(self.Y[self._parse_x_name(x_file)]) for x_file in self.X[index]]
        return wav_batch, label_batch # bucketing, return ((wavs, labels))

    def collate_fn(self, items):
        """Unwrap the single bucket in ``items``; returns (wavs, labels)."""
        return items[0][0], items[0][1] # hack bucketing, return (wavs, labels)