from itertools import accumulate
from pathlib import Path

import torch
from torch.utils.data.dataset import Dataset
from torchaudio.sox_effects import apply_effects_file
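
# Segment-level datasets for MOS prediction: each utterance is loaded as
# mono 16 kHz audio and cut into overlapping one-second segments (see
# unfold_segments below). The collate functions flatten all segments in a
# batch and return prefix sums so per-utterance slices can be recovered.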


class VCC18SegmentalDataset(Dataset):
    """Segment-level dataset over VCC2018 listening-test ratings."""

    def __init__(self, dataframe, base_path, idtable='', valid=False):
        self.base_path = Path(base_path)
        self.dataframe = dataframe
        self.segments_durations = 1  # segment length in seconds

        # Reuse a saved judge-id table when one exists; otherwise generate
        # one (only the training split should create a new table).
        if idtable and Path(idtable).is_file():
            self.idtable = torch.load(idtable)
            self.dataframe['JUDGE'] = self.dataframe['JUDGE'].map(self.idtable)
        elif not valid:
            self.gen_idtable(idtable)

def __len__(self):
return len(self.dataframe)

    def __getitem__(self, idx):
        wav_name, mean, mos, judge_id = self.dataframe.loc[idx]
        wav_path = self.base_path / "Converted_speech_of_submitted_systems" / wav_name

        # Load as mono 16 kHz audio with peak normalization.
        wav, _ = apply_effects_file(
            str(wav_path),
            [
                ["channels", "1"],
                ["rate", "16000"],
                ["norm"],
            ],
        )
        wav = wav.view(-1)
        wav_segments = unfold_segments(wav, self.segments_durations)

        # The system identifier is encoded in the file name: the first three
        # characters plus the four just before the extension.
        system_name = wav_name[:3] + wav_name[-8:-4]
        return wav_segments, mean, system_name, mos, judge_id

    def collate_fn(self, samples):
        wavs_segments, means, system_names, moss, judge_ids = zip(*samples)

        # Flatten every utterance's segments into a single batch dimension.
        flattened_wavs_segments = [
            wav_segment
            for wav_segments in wavs_segments
            for wav_segment in wav_segments
        ]

        # Prefix sums of segment counts mark each utterance's slice in the
        # flattened batch.
        wav_segments_lengths = [len(wav_segments) for wav_segments in wavs_segments]
        prefix_sums = list(accumulate(wav_segments_lengths, initial=0))

        # Repeat each judge id once per segment of its utterance.
        segment_judge_ids = []
        for judge_id, length in zip(judge_ids, wav_segments_lengths):
            segment_judge_ids.extend([judge_id] * length)

        return (
            torch.stack(flattened_wavs_segments),
            prefix_sums,
            torch.FloatTensor(means),
            system_names,
            torch.FloatTensor(moss),
            torch.LongTensor(segment_judge_ids),
        )
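
    # collate_fn example: a batch whose utterances yield 3 and 2 segments
    # gives prefix_sums == [0, 3, 5]; downstream, segment-level outputs for
    # utterance i live at indices prefix_sums[i]:prefix_sums[i + 1].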

    def gen_idtable(self, idtable_path):
        """Assign each judge a contiguous integer id and save the mapping."""
        if idtable_path == '':
            idtable_path = './idtable.pkl'
        self.idtable = {}
        count = 0
        for judge_i in self.dataframe['JUDGE']:
            if judge_i not in self.idtable:
                self.idtable[judge_i] = count
                count += 1
        self.dataframe['JUDGE'] = self.dataframe['JUDGE'].map(self.idtable)
        torch.save(self.idtable, idtable_path)
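
    # Intended pairing (train_df / valid_df are hypothetical names):
    #   train_set = VCC18SegmentalDataset(train_df, base_path)  # writes ./idtable.pkl
    #   valid_set = VCC18SegmentalDataset(valid_df, base_path,
    #                                     idtable='./idtable.pkl', valid=True)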


class VCC16SegmentalDataset(Dataset):
    """Segment-level dataset over a plain list of VCC2016 waveform paths."""

    def __init__(self, wav_list, base_path):
        self.wav_dir = Path(base_path)
        self.wav_list = wav_list
        self.segments_durations = 1  # segment length in seconds

    def __len__(self):
        return len(self.wav_list)

    def __getitem__(self, idx):
        wav_name = self.wav_list[idx]
        wav_path = self.wav_dir / wav_name

        # Load as mono 16 kHz audio with peak normalization.
        wav, _ = apply_effects_file(
            str(wav_path),
            [
                ["channels", "1"],
                ["rate", "16000"],
                ["norm"],
            ],
        )
        wav = wav.view(-1)
        wav_segments = unfold_segments(wav, self.segments_durations)

        # The system identifier is the leading underscore-separated field.
        system_name = wav_name.name.split("_")[0]
        return wav_segments, system_name

    def collate_fn(self, samples):
        wavs_segments, system_names = zip(*samples)

        # Flatten every utterance's segments into a single batch dimension.
        flattened_wavs_segments = [
            wav_segment
            for wav_segments in wavs_segments
            for wav_segment in wav_segments
        ]

        # Prefix sums of segment counts mark each utterance's slice in the
        # flattened batch.
        wav_segments_lengths = [len(wav_segments) for wav_segments in wavs_segments]
        prefix_sums = list(accumulate(wav_segments_lengths, initial=0))
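
        # None placeholders keep the return arity aligned with
        # VCC18SegmentalDataset.collate_fn (this dataset has no
        # mean/MOS/judge labels).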
return (
torch.stack(flattened_wavs_segments),
prefix_sums,
None,
system_names,
None,
None,
)


def unfold_segments(tensor, tgt_duration, sample_rate=16000):
    """Slice a 1-D waveform into fixed-length segments with 50% overlap."""
    seg_length = int(tgt_duration * sample_rate)
    src_length = len(tensor)
    step = seg_length // 2  # hop size: half a segment

    # Zero-pad so the sliding windows cover every sample; inputs shorter
    # than one segment are padded up to a single full segment.
    tgt_length = (
        seg_length if src_length <= seg_length else (src_length // step + 1) * step
    )
    pad_length = tgt_length - src_length
    padded_tensor = torch.cat([tensor, torch.zeros(pad_length)])

    # unfold yields the overlapping windows; unbind returns them as a tuple.
    segments = padded_tensor.unfold(0, seg_length, step).unbind(0)
    return segments
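

# Minimal usage sketch, for illustration only. `score_df` and the data
# directory are hypothetical; the dataframe is assumed to hold one row per
# rating with columns (wav name, mean score, MOS, JUDGE), matching the
# unpacking in VCC18SegmentalDataset.__getitem__:
#
#   from torch.utils.data import DataLoader
#   train_set = VCC18SegmentalDataset(score_df, "data/VCC2018")
#   loader = DataLoader(train_set, batch_size=4, shuffle=True,
#                       collate_fn=train_set.collate_fn)
#   wavs, prefix_sums, means, systems, moss, judges = next(iter(loader))
#   # wavs: (total_segments, 16000) float tensor of one-second segments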