from pathlib import Path
from itertools import accumulate

import torch
from torch.utils.data.dataset import Dataset
from torchaudio.sox_effects import apply_effects_file


class VCC18SegmentalDataset(Dataset):
    """Segment-level dataset over the VCC2018 listening-test scores."""

    def __init__(self, dataframe, base_path, idtable='', valid=False):
        self.base_path = Path(base_path)
        self.dataframe = dataframe
        self.segments_durations = 1  # segment length in seconds

        if idtable != '' and Path(idtable).is_file():
            # Reuse an existing judge-id table and remap the JUDGE column to
            # the integer ids it defines.
            self.idtable = torch.load(idtable)
            for i, judge_i in enumerate(self.dataframe['JUDGE']):
                self.dataframe.loc[i, 'JUDGE'] = self.idtable[judge_i]
        elif not valid:
            self.gen_idtable(idtable)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        wav_name, mean, mos, judge_id = self.dataframe.loc[idx]
        wav_path = self.base_path / "Converted_speech_of_submitted_systems" / wav_name

        # Load as mono 16 kHz and peak-normalize via sox effects.
        wav, _ = apply_effects_file(
            str(wav_path),
            [
                ["channels", "1"],
                ["rate", "16000"],
                ["norm"],
            ],
        )
        wav = wav.view(-1)

        # Split the utterance into fixed-length, half-overlapping segments.
        wav_segments = unfold_segments(wav, self.segments_durations)

        system_name = wav_name[:3] + wav_name[-8:-4]
        return wav_segments, mean, system_name, mos, judge_id

    def collate_fn(self, samples):
        wavs_segments, means, system_names, moss, judge_ids = zip(*samples)

        # Flatten the per-utterance segment lists into one batch dimension and
        # keep prefix sums so each utterance's segments can be recovered later.
        flattened_wavs_segments = [
            wav_segment
            for wav_segments in wavs_segments
            for wav_segment in wav_segments
        ]
        wav_segments_lengths = [len(wav_segments) for wav_segments in wavs_segments]
        prefix_sums = list(accumulate(wav_segments_lengths, initial=0))

        # Repeat each judge id once per segment of its utterance.
        segment_judge_ids = []
        for i in range(len(prefix_sums) - 1):
            segment_judge_ids.extend(
                [judge_ids[i]] * (prefix_sums[i + 1] - prefix_sums[i])
            )

        return (
            torch.stack(flattened_wavs_segments),
            prefix_sums,
            torch.FloatTensor(means),
            system_names,
            torch.FloatTensor(moss),
            torch.LongTensor(segment_judge_ids),
        )

    def gen_idtable(self, idtable_path):
        """Build a judge-name -> integer-id table and remap the JUDGE column."""
        if idtable_path == '':
            idtable_path = './idtable.pkl'
        self.idtable = {}
        count = 0
        for i, judge_i in enumerate(self.dataframe['JUDGE']):
            if judge_i not in self.idtable:
                self.idtable[judge_i] = count
                count += 1
            self.dataframe.loc[i, 'JUDGE'] = self.idtable[judge_i]
        torch.save(self.idtable, idtable_path)


class VCC16SegmentalDataset(Dataset):
    """Segment-level dataset over VCC2016 waveforms (inference only, no scores)."""

    def __init__(self, wav_list, base_path):
        self.wav_dir = Path(base_path)
        self.wav_list = wav_list
        self.segments_durations = 1  # segment length in seconds

    def __len__(self):
        return len(self.wav_list)

    def __getitem__(self, idx):
        wav_name = self.wav_list[idx]
        wav_path = self.wav_dir / wav_name

        # Load as mono 16 kHz and peak-normalize via sox effects.
        wav, _ = apply_effects_file(
            str(wav_path),
            [
                ["channels", "1"],
                ["rate", "16000"],
                ["norm"],
            ],
        )
        wav = wav.view(-1)
        wav_segments = unfold_segments(wav, self.segments_durations)

        system_name = Path(wav_name).name.split("_")[0]
        return wav_segments, system_name

    def collate_fn(self, samples):
        wavs_segments, system_names = zip(*samples)

        flattened_wavs_segments = [
            wav_segment
            for wav_segments in wavs_segments
            for wav_segment in wav_segments
        ]
        wav_segments_lengths = [len(wav_segments) for wav_segments in wavs_segments]
        prefix_sums = list(accumulate(wav_segments_lengths, initial=0))

        # Scores and judge ids are unknown at inference time, so those slots
        # are filled with None to match the VCC18 collate signature.
        return (
            torch.stack(flattened_wavs_segments),
            prefix_sums,
            None,
            system_names,
            None,
            None,
        )


def unfold_segments(tensor, tgt_duration, sample_rate=16000):
    """Split a 1-D waveform into half-overlapping segments of tgt_duration
    seconds, zero-padding the tail so every segment has the same length."""
    seg_lengths = int(tgt_duration * sample_rate)
    src_lengths = len(tensor)
    step = seg_lengths // 2

    tgt_lengths = (
        seg_lengths
        if src_lengths <= seg_lengths
        else (src_lengths // step + 1) * step
    )
    pad_lengths = tgt_lengths - src_lengths

    padded_tensor = torch.cat([tensor, torch.zeros(pad_lengths)])
    segments = padded_tensor.unfold(0, seg_lengths, step).unbind(0)
    return segments
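

# ---------------------------------------------------------------------------
# Minimal self-check / usage sketch (illustrative addition, not part of the
# original module).  It only exercises unfold_segments, which needs no data
# on disk; the dummy waveform length (2.5 s at 16 kHz) is an arbitrary
# assumption.  Using the datasets themselves would additionally require the
# VCC2018/VCC2016 corpora and a score dataframe, wired up along the lines of
#     loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dummy = torch.randn(int(2.5 * 16000))        # 40000 samples
    segs = unfold_segments(dummy, tgt_duration=1)
    # 1 s segments with a 0.5 s hop: the input is padded to 48000 samples,
    # giving 5 segments of 16000 samples each.
    print(len(segs), segs[0].shape)              # -> 5 torch.Size([16000])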