# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
modified from lhoste.dataset.speech_synthesis.py | |
""" | |
import math
from typing import List, Union

import h5py
import numpy as np
import torch
from tokenizers import Tokenizer
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
symbols = [_pad] + list(_punctuation) + list(_letters)

language_dict = {
    'en': 0,
    'zh': 1,
    'ja': 2,
}


def seq2phone(tokens: Union[List, np.ndarray]):
    """
    Convert a tokenized phoneme ID sequence back to a phoneme string.

    :param tokens: phoneme token IDs
    :return: recovered phoneme string
    """
    phones = "".join([symbols[i] for i in tokens])
    return phones
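
# Hedged illustration of the mapping above: each ID indexes into `symbols`,
# so with this table seq2phone([0, 1, 2]) == '_,.' (the pad symbol followed
# by the first two punctuation marks). Example only, not used by the pipeline.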


class DynamicBatchSampler(torch.utils.data.Sampler):
    def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
                 max_tokens=None, max_sentences=None, drop_last=False):
        """
        :param sampler: base sampler providing the iteration order of indices
        :param num_tokens_fn: function that returns the length of the sample at a given index
        :param num_buckets: number of buckets; samples of similar length are grouped into
            the same bucket so that each batch has roughly uniform lengths
        :param min_size: minimum sample length; shorter samples are filtered out.
            The buckets are laid out starting from this value.
        :param max_size: maximum sample length; longer samples are filtered out
        :param max_tokens: cap on the padded token count of a batch
        :param max_sentences: cap on the number of sentences per batch; the final batch
            size is controlled jointly by max_sentences and max_tokens
        :param drop_last: if True, drop the final incomplete batch
        """
        super(DynamicBatchSampler, self).__init__(sampler)
        self.sampler = sampler
        self.num_tokens_fn = num_tokens_fn
        self.num_buckets = num_buckets

        self.min_size = min_size
        self.max_size = max_size

        assert max_tokens is not None or max_sentences is not None, \
            "max_tokens and max_sentences should not both be None; please specify at least one"
        assert max_tokens is None or max_size <= max_tokens, \
            "max_size should not exceed max_tokens"
        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
        self.drop_last = drop_last

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def is_batch_full(self, num_tokens, batch):
        if len(batch) == 0:
            return False
        if len(batch) == self.max_sentences:
            return True
        if num_tokens > self.max_tokens:
            return True
        return False
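
    # The bucket index for a sample of length L is
    #     floor((L - min_size) / (max_size - min_size + 1) * num_buckets),
    # e.g. with the defaults (min_size=0, max_size=1000, num_buckets=100) a
    # 250-token sample lands in bucket floor(250 / 1001 * 100) = 24.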
    def __iter__(self):
        buckets = [[] for _ in range(self.num_buckets)]
        sample_len = [0] * self.num_buckets

        for idx in self.sampler:
            idx_length = self.num_tokens_fn(idx)
            if not (self.min_size <= idx_length <= self.max_size):
                print("sentence at index {} has size {}, outside [min_size, max_size]; the sentence is ignored".format(idx, idx_length))
                continue

            index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
                                       * self.num_buckets)
            sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)

            num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
            if self.is_batch_full(num_tokens, buckets[index_buckets]):
                # yield this batch
                yield buckets[index_buckets]
                buckets[index_buckets] = []
                sample_len[index_buckets] = 0

            buckets[index_buckets].append(idx)
        # process leftovers
        leftover_batch = []
        leftover_sample_len = 0
        leftover = [idx for bucket in buckets for idx in bucket]
        for idx in leftover:
            idx_length = self.num_tokens_fn(idx)
            leftover_sample_len = max(leftover_sample_len, idx_length)
            num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
            if self.is_batch_full(num_tokens, leftover_batch):
                yield leftover_batch
                leftover_batch = []
                leftover_sample_len = 0
            leftover_batch.append(idx)

        if len(leftover_batch) > 0 and not self.drop_last:
            yield leftover_batch
    def __len__(self):
        # the exact number of batches is data-dependent and unknown in advance,
        # so do not call len() on this sampler or on the dataloader
        pass
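
# Minimal usage sketch for DynamicBatchSampler (hedged: a SequentialSampler
# and a toy length list stand in for the DistributedSampler/dataset used in
# create_dataloader below):
#
#   lengths = [3, 9, 4, 8, 2, 7]
#   sampler = torch.utils.data.SequentialSampler(lengths)
#   batch_sampler = DynamicBatchSampler(
#       sampler, lambda i: lengths[i],
#       num_buckets=2, max_size=10, max_tokens=16)
#   for batch in batch_sampler:
#       print(batch)  # lists of indices with similar lengths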


class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, h5_path, ann_path, tokenizer_path):
        self.h5_path = h5_path
        with open(ann_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        ls = [l.split("|") for l in lines]
        ls_T = list(zip(*ls))
        del ls_T[-1]  # drop the trailing field of each record
        self.h5_paths, self.durations, self.langs, self.texts = \
            list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3])
        self.durations = [float(dur) for dur in self.durations]
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self._archive = None
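
    # Annotation format inferred from the parsing above (hedged): one record
    # per line, pipe-separated, with a trailing field that is discarded, e.g.
    #
    #   <h5_key>|<duration_seconds>|<lang>|<text>|<extra>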
    def __len__(self):
        return len(self.h5_paths)

    def get_dur(self, idx):
        return self.durations[idx]

    @property
    def archive(self):
        if self._archive is None:  # lazy loading here!
            self._archive = h5py.File(self.h5_path, "r")
        return self._archive
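
    # Note: opening the HDF5 file lazily (on first access) rather than in
    # __init__ matters once num_workers > 0, because h5py file handles do not
    # survive the fork into DataLoader worker processes; this way each worker
    # opens its own handle.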
    def __getitem__(self, idx):
        archive = self.archive
        h5_path = self.h5_paths[idx]
        sub = archive[h5_path]
        audio_tokens = sub['audio'][()]
        phone_tokens = sub['text'][()]

        dur = self.durations[idx]
        lang = self.langs[idx]
        text = self.texts[idx]
        # tokenization should be done within the dataloader
        phones = seq2phone(phone_tokens)
        phones = phones.replace(" ", "_")
        if not len(phones):
            cptpho_tokens = self.tokenizer.encode(text).ids
        else:
            cptpho_tokens = self.tokenizer.encode(phones).ids
        assert len(cptpho_tokens)
        return {
            'utt_id': h5_path,
            'text': text,
            'audio': None,
            'audio_lens': None,
            'audio_features': audio_tokens,
            'audio_features_lens': len(audio_tokens.T),  # number of frames
            'text_tokens': np.array(cptpho_tokens),
            'text_tokens_lens': len(cptpho_tokens),
            'language': language_dict[lang],
        }


def collate(batch):
    utt_id_s = [b['utt_id'] for b in batch]
    text_s = [b['text'] for b in batch]

    audio_s = [b['audio'] for b in batch]
    audio_lens_s = [b['audio_lens'] for b in batch]

    audio_features_lens_s = [b['audio_features_lens'] for b in batch]
    # pad audio features to the maximum length in the batch; -1 is the audio pad value
    audio_features_s = torch.full([len(batch), max(audio_features_lens_s), 8], -1, dtype=torch.int64)

    text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
    # pad text tokens to the maximum length in the batch; 3 is the [PAD] token id
    text_tokens_s = torch.full([len(batch), max(text_tokens_lens_s)], 3, dtype=torch.int64)

    language_s = [b['language'] for b in batch]

    for i, b in enumerate(batch):
        audio_features = b['audio_features']
        audio_features_lens = b['audio_features_lens']
        audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features.T)

        text_tokens = b['text_tokens']
        text_tokens_lens = b['text_tokens_lens']
        text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens)

    batch = {
        'utt_id': utt_id_s,
        'text': text_s,
        'audio': audio_s,
        'audio_lens': audio_lens_s,
        'audio_features': audio_features_s,
        'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
        'text_tokens': text_tokens_s,
        'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
        'languages': torch.LongTensor(np.array(language_s)),
    }
    return batch
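
# Shape sketch for a collated batch (hedged: assumes each item's
# 'audio_features' is an (8, T) integer array, as the .T transposes above imply):
#   audio_features : (B, max_T, 8), padded with -1
#   text_tokens    : (B, max_L),    padded with [PAD] id 3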


def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10,
                      max_duration=120):
    train_dataset = AudioDataset(h5_path=f"{data_dir}/audio_sum.hdf5",
                                 ann_path=f"{data_dir}/audio_ann_sum.txt",
                                 tokenizer_path=f"{data_dir}/bpe_69.json")
    ran_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    dynamic_sampler = DynamicBatchSampler(ran_sampler, train_dataset.get_dur, num_buckets=num_buckets,
                                          max_size=20, max_tokens=max_duration)

    train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=num_workers, collate_fn=collate,
                                               batch_sampler=dynamic_sampler)
    return train_loader
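

if __name__ == "__main__":
    # Minimal smoke test (hedged: assumes the default data_dir layout exists
    # with audio_sum.hdf5, audio_ann_sum.txt and bpe_69.json; a single-replica
    # DistributedSampler needs no process-group initialization because
    # num_replicas and rank are passed explicitly).
    loader = create_dataloader(n_gpus=1, rank=0, num_workers=0)
    batch = next(iter(loader))
    print(batch['audio_features'].shape, batch['text_tokens'].shape)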