Spaces:
Running
Running
File size: 7,366 Bytes
b971d47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import torch
import random
import copy
import logging
import shutil
class dataset(torch.utils.data.Dataset):
    """Dataset pairing phoneme sequences with multi-codebook audio codes.

    Every sample is read lazily from two text files that share the same stem
    (the second column of the manifest): one under ``args.phn_folder_name``
    holding a single line of space-separated phonemes, and one under
    ``args.encodec_folder_name`` holding one line of integer codes per
    codebook.

    Attributes read from ``args``: ``dataset_dir``, ``manifest_name``,
    ``exp_dir``, ``phn_folder_name``, ``encodec_folder_name``, ``encodec_sr``,
    ``audio_min_length``, ``audio_max_length``, ``text_max_length``,
    ``drop_long``, ``n_codebooks``, ``special_first``, ``n_special``,
    ``dynamic_batching``, ``pad_x``, ``sep_special_token``,
    ``audio_pad_token`` and ``text_pad_token``.
    """

    def __init__(self, args, split):
        """Read the manifest of ``split`` and the phoneme vocabulary.

        Args:
            args: configuration namespace (see the class docstring).
            split: one of ``'train'``, ``'validation'`` or ``'test'``.

        Side effects:
            Copies ``<dataset_dir>/vocab.txt`` into ``args.exp_dir``
            (the directory is created when missing).
        """
        super().__init__()
        self.args = args
        self.split = split
        assert self.split in ['train', 'validation', 'test']

        manifest_fn = os.path.join(self.args.dataset_dir, self.args.manifest_name, self.split + ".txt")
        with open(manifest_fn, "r") as rf:
            # Blank lines (e.g. a trailing newline) carry no entry; skip them
            # instead of crashing on int('') below.
            data = [l.strip().split("\t") for l in rf if l.strip()]

        # The last manifest column is the item length; it is compared against
        # encodec_sr * seconds, so it is presumably a count of codec frames.
        min_frames = self.args.encodec_sr * self.args.audio_min_length
        max_frames = self.args.encodec_sr * self.args.audio_max_length
        self.data = []
        self.lengths_list = []
        for d in data:
            l = int(d[-1])
            if l < min_frames:
                continue
            if self.args.drop_long and l > max_frames:
                continue
            self.data.append(d)
            self.lengths_list.append(l)
        logging.info(f"number of data points for {self.split} split: {len(self.lengths_list)}")

        # phoneme vocabulary
        vocab_fn = os.path.join(self.args.dataset_dir, "vocab.txt")
        os.makedirs(self.args.exp_dir, exist_ok=True)
        shutil.copy(vocab_fn, os.path.join(self.args.exp_dir, "vocab.txt"))
        with open(vocab_fn, "r") as f:
            # Each line is "<index> <phoneme>". Test the stripped line: a raw
            # line always contains at least "\n", so len(l) != 0 never filters.
            temp = [l.strip().split(" ") for l in f if l.strip()]
        self.phn2num = {item[1]: int(item[0]) for item in temp}
        self.symbol_set = {"<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"}

    def __len__(self):
        """Return the number of manifest entries kept after length filtering."""
        return len(self.lengths_list)

    def _load_phn_enc(self, index):
        """Load the phoneme ids and codec codes of manifest entry ``index``.

        Returns:
            ``(x, y)`` where ``x`` is a list of phoneme ids and ``y`` a list
            of ``args.n_codebooks`` equally long lists of ints. On any failure
            (missing file, malformed content, unknown phoneme) the problem is
            logged and ``([], [[]])`` is returned instead of raising.
        """
        item = self.data[index]
        pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1] + ".txt")
        ef = os.path.join(self.args.dataset_dir, self.args.encodec_folder_name, item[1] + ".txt")
        try:
            with open(pf, "r") as phn_file, open(ef, "r") as enc_file:
                phns = [l.strip() for l in phn_file.readlines()]
                # Explicit raises instead of asserts so the validation still
                # runs when Python is started with -O.
                if len(phns) != 1:
                    raise ValueError(f"expected exactly one phoneme line, got {len(phns)}: {phns}")
                # drop ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"], as they are not in training set annotation
                x = [self.phn2num[phn] for phn in phns[0].split(" ") if phn not in self.symbol_set]
                # Only the first n_codebooks lines are used.
                encos = [l.strip().split() for l in enc_file.readlines()[:self.args.n_codebooks]]
                if len(encos) != self.args.n_codebooks:
                    raise ValueError(f"expected {self.args.n_codebooks} codebooks in {ef}, got {len(encos)}")
                if len({len(row) for row in encos}) > 1:
                    # Ragged rows cannot be turned into a tensor later on.
                    raise ValueError(f"codebooks in {ef} have different lengths")
                # With special_first the special tokens occupy the lowest ids,
                # so every code is shifted up by n_special.
                offset = self.args.n_special if self.args.special_first else 0
                y = [[int(n) + offset for n in row] for row in encos]
        except Exception as err:
            # Deliberately best-effort: a bad item must not kill the data
            # loader, the caller decides how to handle the empty result.
            logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
            logging.info(f"error message: {err}")
            return [], [[]]
        return x, y

    def __getitem__(self, index):
        """Return one sample, cropped and/or padded to the configured lengths.

        If the requested item cannot be loaded (or has no phonemes / codes),
        a dict whose values are all ``None`` is returned; ``collate`` skips
        such entries. An item that is too short, or too long while
        ``args.drop_long`` is set, is replaced by a randomly drawn item.

        Returns:
            dict with ``x`` (LongTensor of phoneme ids), ``x_len`` (number of
            unpadded phonemes), ``y`` (LongTensor, one row per codebook) and
            ``y_len`` (number of unpadded code frames).

        Raises:
            AssertionError: if a too-short item is met while
                ``args.dynamic_batching`` is enabled.
        """
        x, y = self._load_phn_enc(index)
        x_len, y_len = len(x), len(y[0])
        if x_len == 0 or y_len == 0:
            return {
                "x": None,
                "x_len": None,
                "y": None,
                "y_len": None,
                "y_mask_interval": None,  # index y_mask_interval[1] is the position of start_of_continue token
                "extra_mask_start": None  # this is only used in VE1
            }

        text_max = self.args.text_max_length
        min_frames = self.args.encodec_sr * self.args.audio_min_length
        # Draw replacements until one passes *all* checks. Every freshly drawn
        # item is validated again, so a replacement that failed to load or is
        # itself out of range is never accepted.
        # NOTE: like the original loops this does not terminate if no item in
        # the dataset is acceptable.
        while True:
            if x_len == 0 or y_len == 0:
                pass  # the replacement failed to load, draw again
            elif y_len < min_frames:
                # Resampling is only allowed without dynamic batching.
                assert not self.args.dynamic_batching
            elif self.args.drop_long and (
                x_len > text_max or y_len > self.args.encodec_sr * self.args.audio_max_length
            ):
                pass  # too long and long items are to be dropped
            else:
                break
            index = random.randrange(len(self))  # regenerate an index
            x, y = self._load_phn_enc(index)
            x_len, y_len = len(x), len(y[0])

        ### padding and cropping below ###
        # adjust the length of encodec codes, pad to max_len or randomly crop
        orig_y_len = y_len
        max_len = int(self.args.audio_max_length * self.args.encodec_sr)
        if y_len > max_len:
            # randint is inclusive on both ends, so the window that ends on
            # the very last frame can be selected as well.
            audio_start = random.randint(0, y_len - max_len)
            y = [row[audio_start:audio_start + max_len] for row in y]
            y_len = max_len
        else:
            audio_start = 0
            if not self.args.dynamic_batching:
                pad_value = 0 if self.args.sep_special_token else self.args.audio_pad_token
                pad = [pad_value] * (max_len - y_len)
                y = [row + pad for row in y]

        # adjust text
        # if audio is cropped, and text is longer than max, crop max based on how audio is cropped
        if audio_start > 0 and len(x) > text_max:
            # start the text at the same relative position as the audio crop,
            # then cut the end if it is still longer than the maximum
            x = x[int(len(x) * audio_start / orig_y_len):]
            x = x[:text_max]
            x_len = len(x)
        if x_len > text_max:
            text_start = random.randint(0, x_len - text_max)
            x = x[text_start:text_start + text_max]
            x_len = text_max
        elif self.args.pad_x:
            pad_value = 0 if self.args.sep_special_token else self.args.text_pad_token
            x = x + [pad_value] * (text_max - x_len)
        ### padding and cropping above ###

        return {
            "x": torch.LongTensor(x),
            "x_len": x_len,
            "y": torch.LongTensor(y),
            "y_len": y_len
        }

    def collate(self, batch):
        """Merge samples produced by ``__getitem__`` into one batch.

        Samples whose ``x`` is ``None`` (load failures) are skipped.

        Returns:
            dict with ``x``, ``x_lens``, ``y``, ``y_lens`` and the boolean
            masks ``text_padding_mask`` / ``audio_padding_mask`` which are
            ``True`` at positions at or beyond the respective length.

        Note:
            If every sample of ``batch`` failed to load there is nothing to
            stack and the underlying torch call raises.
        """
        out = {key: [] for key in batch[0]}
        for item in batch:
            if item['x'] is None:  # deal with load failure
                continue
            for key, val in item.items():
                out[key].append(val)

        res = {}
        if self.args.pad_x:
            # x was already padded to a fixed length in __getitem__
            res["x"] = torch.stack(out["x"], dim=0)
        else:
            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
        res["x_lens"] = torch.LongTensor(out["x_len"])

        if self.args.dynamic_batching:
            if out['y'][0].ndim == 2:
                # pad_sequence pads along the first dimension, so put time first
                res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1, 0) for item in out['y']], padding_value=self.args.audio_pad_token)
                res['y'] = res['y'].permute(1, 2, 0)  # T B K -> B K T
            else:
                assert out['y'][0].ndim == 1, out['y'][0].shape
                res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
        else:
            res['y'] = torch.stack(out['y'], dim=0)
        res["y_lens"] = torch.LongTensor(out["y_len"])

        res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
        res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
        return res