import torch
import random
from TTS.tts.models.xtts import load_audio
torch.set_num_threads(1)
def key_samples_by_col(samples, col):
    """Return a dict mapping each distinct value of `col` to the list of samples that have it."""
    samples_by_col = {}
    for sample in samples:
        col_val = sample[col]
        assert isinstance(col_val, str)
        if col_val not in samples_by_col:
            samples_by_col[col_val] = []
        samples_by_col[col_val].append(sample)
    return samples_by_col
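# Illustrative example (hypothetical file names), grouping by "language":
#
#     key_samples_by_col(
#         [{"audio_file": "a.wav", "language": "en"},
#          {"audio_file": "b.wav", "language": "pt"}],
#         "language",
#     )
#     # -> {"en": [{"audio_file": "a.wav", ...}], "pt": [{"audio_file": "b.wav", ...}]}
#
# This grouped structure is what the training branch of DVAEDataset below
# samples from, one language at a time.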
class DVAEDataset(torch.utils.data.Dataset):
    def __init__(self, samples, sample_rate, is_eval, max_wav_len=255995):
        self.sample_rate = sample_rate
        self.is_eval = is_eval
        self.max_wav_len = max_wav_len
        self.samples = samples
        self.training_seed = 1
        self.failed_samples = set()
        if not is_eval:
            random.seed(self.training_seed)
            random.shuffle(self.samples)
            # group the samples by language so training can sample per language
            self.samples = key_samples_by_col(self.samples, "language")
            print(" > Sampling by language:", self.samples.keys())
        else:
            # for evaluation, load the samples up front and drop corrupted ones,
            # so the eval set is fixed and results are reproducible
            self.check_eval_samples()
    def check_eval_samples(self):
        print(" > Filtering invalid eval samples!!")
        new_samples = []
        for sample in self.samples:
            try:
                _, wav = self.load_item(sample)
            except Exception:
                continue
            # skip audio files that are missing or too long for the dataset
            if wav is None or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len):
                continue
            new_samples.append(sample)
        self.samples = new_samples
        print(" > Total eval samples after filtering:", len(self.samples))
    def load_item(self, sample):
        audiopath = sample["audio_file"]
        wav = load_audio(audiopath, self.sample_rate)
        if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
            # ultra-short clips (< 0.5 s) are useless and can cause problems in some models
            raise ValueError("Audio is missing or shorter than 0.5 seconds")
        return audiopath, wav
    def __getitem__(self, index):
        if self.is_eval:
            sample = self.samples[index]
            sample_id = str(index)
        else:
            # select a random language, then a random sample within it
            lang = random.choice(list(self.samples.keys()))
            index = random.randint(0, len(self.samples[lang]) - 1)
            sample = self.samples[lang][index]
            # a unique id per sample, so failures can be tracked
            sample_id = lang + "_" + str(index)
        # skip samples that are already known to be invalid
        if sample_id in self.failed_samples:
            # call __getitem__ again to draw another sample (the index is
            # ignored in training mode, so this effectively re-samples)
            return self[1]
        # try to load the sample; on failure, record it and draw another one
        try:
            audiopath, wav = self.load_item(sample)
        except Exception:
            self.failed_samples.add(sample_id)
            return self[1]
        # enforce the audio length limit; out-of-range samples go to failed_samples
        if wav is None or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len):
            # The audio file is nonexistent or too long for the dataset. There is
            # no clean way to handle this, so the best bet is to return another
            # valid sample, which skews the dataset somewhat as a result.
            self.failed_samples.add(sample_id)
            return self[1]
        res = {
            "wav": wav,
            "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
            "filenames": audiopath,
        }
        return res
    def __len__(self):
        if self.is_eval:
            return len(self.samples)
        return sum(len(v) for v in self.samples.values())
    def collate_fn(self, batch):
        # convert a list of dicts to a dict of lists
        B = len(batch)
        batch = {k: [dic[k] for dic in batch] for k in batch[0]}
        # stack features that already share a shape
        batch["wav_lengths"] = torch.stack(batch["wav_lengths"])
        max_wav_len = batch["wav_lengths"].max()
        # zero-initialized padding tensor for the waveforms
        wav_padded = torch.FloatTensor(B, 1, max_wav_len).zero_()
        for i in range(B):
            wav = batch["wav"][i]
            wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)
        batch["wav"] = wav_padded
        return batch
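# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The wav paths and the
# 22050 Hz sample rate are illustrative assumptions; substitute your own
# metadata. Each sample dict must provide "audio_file" and "language" keys.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    samples = [
        {"audio_file": "data/clip_0001.wav", "language": "en"},  # hypothetical paths
        {"audio_file": "data/clip_0002.wav", "language": "pt"},
    ]
    dataset = DVAEDataset(samples, sample_rate=22050, is_eval=False)
    loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=False,  # training mode already randomizes inside __getitem__
        collate_fn=dataset.collate_fn,
    )
    batch = next(iter(loader))
    print(batch["wav"].shape)    # (B, 1, longest wav in the batch)
    print(batch["wav_lengths"])  # original lengths before padding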