import os
import random

import librosa
import soundfile as sf
import torch
from speechbrain.pretrained import EncoderClassifier
from torch.multiprocessing import Manager
from torch.multiprocessing import Process
from torch.utils.data import Dataset
from torchaudio.transforms import Resample
from tqdm import tqdm

from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Utility.storage_config import MODELS_DIR


class CodecAlignerDataset(Dataset):
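    """
    Dataset that prepares and caches everything the aligner needs: articulatory text
    features, EnCodec codebook indexes, and an ECAPA speaker embedding per utterance.
    """
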
    def __init__(self,
                 path_to_transcript_dict,
                 cache_dir,
                 lang,
                 loading_processes,
                 device,
                 min_len_in_seconds=1,
                 max_len_in_seconds=15,
                 rebuild_cache=False,
                 verbose=False,
                 phone_input=False,
                 allow_unknown_symbols=False,
                 gpu_count=1,
                 rank=0):
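        """
        Loads the cached dataset from cache_dir, building the cache first if it does not
        exist yet or if rebuild_cache is set. When training on multiple GPUs, only the
        slice of the datapoints belonging to this rank is kept in memory.
        """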
        self.gpu_count = gpu_count
        self.rank = rank
        if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
            self._build_dataset_cache(path_to_transcript_dict=path_to_transcript_dict,
                                      cache_dir=cache_dir,
                                      lang=lang,
                                      loading_processes=loading_processes,
                                      device=device,
                                      min_len_in_seconds=min_len_in_seconds,
                                      max_len_in_seconds=max_len_in_seconds,
                                      verbose=verbose,
                                      phone_input=phone_input,
                                      allow_unknown_symbols=allow_unknown_symbols,
                                      gpu_count=gpu_count,
                                      rank=rank)
        self.lang = lang
        self.device = device
        self.cache_dir = cache_dir
        self.tf = ArticulatoryCombinedTextFrontend(language=self.lang, device=device)
        cache = torch.load(os.path.join(self.cache_dir, "aligner_train_cache.pt"), map_location='cpu')
        self.speaker_embeddings = cache[2]
        self.filepaths = cache[3]
        self.datapoints = cache[0]
        if self.gpu_count > 1:
            # keep only the slice of the data that belongs to this rank, so that every GPU holds a distinct chunk of equal size
            while len(self.datapoints) % self.gpu_count != 0:
                self.datapoints.pop(-1)
            chunksize = int(len(self.datapoints) / self.gpu_count)
            self.datapoints = self.datapoints[chunksize * self.rank:chunksize * (self.rank + 1)]
            self.speaker_embeddings = self.speaker_embeddings[chunksize * self.rank:chunksize * (self.rank + 1)]
        print(f"Loaded an Aligner dataset with {len(self.datapoints)} datapoints from {cache_dir}.")

    def _build_dataset_cache(self,
                             path_to_transcript_dict,
                             cache_dir,
                             lang,
                             loading_processes,
                             device,
                             min_len_in_seconds=1,
                             max_len_in_seconds=15,
                             verbose=False,
                             phone_input=False,
                             allow_unknown_symbols=False,
                             gpu_count=1,
                             rank=0):
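        """
        Builds aligner_train_cache.pt: the audio files are cleaned of non-speech with
        Silero VAD and converted to EnCodec codebook indexes, the transcripts are turned
        into articulatory feature vectors, and an ECAPA speaker embedding is extracted
        per utterance. The heavy lifting runs in multiple worker processes and has to be
        done with a single GPU.
        """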
        if gpu_count != 1:
            import sys
            print("Please run the feature extraction using only a single GPU. Multi-GPU is only supported for training.")
            sys.exit()
        os.makedirs(cache_dir, exist_ok=True)
        if type(path_to_transcript_dict) != dict:
            # a callable can be passed instead of a dict, so that building the (potentially large) dict is deferred until the cache actually needs to be built
            path_to_transcript_dict = path_to_transcript_dict()
        # 'spawn' is required so the worker processes can use CUDA; the file_system sharing strategy avoids running out of file descriptors when many arrays are shared
        torch.multiprocessing.set_start_method('spawn', force=True)
        torch.multiprocessing.set_sharing_strategy('file_system')
        resource_manager = Manager()
        self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
        key_list = list(self.path_to_transcript_dict.keys())
        with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
            files_used_note.write(str(key_list))
        fisher_yates_shuffle(key_list)

        print("... building dataset cache ...")
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # skip torch.hub's forked-repo validation, which can fail in some torch versions
        # download and cache the Silero VAD model once here, so the worker processes can load it afterwards without racing each other
        _, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              force_reload=False,
                              onnx=False,
                              verbose=False)
        self.result_pool = resource_manager.list()

        key_splits = list()
        process_list = list()
        for i in range(loading_processes):
            # split the keys into one chunk per worker process
            key_splits.append(
                key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes])
        for key_split in key_splits:
            process_list.append(
                Process(target=self._cache_builder_process,
                        args=(key_split,
                              lang,
                              min_len_in_seconds,
                              max_len_in_seconds,
                              verbose,
                              device,
                              phone_input,
                              allow_unknown_symbols),
                        daemon=True))
            process_list[-1].start()
        for process in process_list:
            process.join()
        print("pooling results...")
        pooled_datapoints = list()
        for chunk in self.result_pool:
            for datapoint in chunk:
                pooled_datapoints.append(datapoint)
        self.result_pool = pooled_datapoints
        del pooled_datapoints
        # the worker processes handed back plain numpy arrays, so everything is converted to tensors here
        print("converting text to tensors...")
        text_tensors = [torch.ShortTensor(x[0]) for x in self.result_pool]
        print("converting speech to tensors...")
        speech_tensors = [torch.ShortTensor(x[1]) for x in self.result_pool]
        print("converting waves to tensors...")
        norm_waves = [torch.Tensor(x[2]) for x in self.result_pool]
        print("unpacking file list...")
        filepaths = [x[3] for x in self.result_pool]
        del self.result_pool
        self.datapoints = list(zip(text_tensors, speech_tensors))
        del text_tensors
        del speech_tensors
        print("done!")

        # extract an ECAPA-TDNN speaker embedding for every silence-trimmed utterance
        self.speaker_embeddings = list()
        speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                      run_opts={"device": str(device)},
                                                                      savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))
        with torch.inference_mode():
            for wave in tqdm(norm_waves):
                self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())

        if len(self.datapoints) == 0:
            raise RuntimeError("No usable datapoints were found while building the dataset cache.")
        torch.save((self.datapoints, None, self.speaker_embeddings, filepaths),
                   os.path.join(cache_dir, "aligner_train_cache.pt"))

    def _cache_builder_process(self,
                               path_list,
                               lang,
                               min_len,
                               max_len,
                               verbose,
                               device,
                               phone_input,
                               allow_unknown_symbols):
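        """
        Worker that handles one chunk of file paths: it loads and resamples the audio,
        silences and trims the non-speech regions found by the VAD, converts transcript
        and audio to features, and appends the results to the shared result pool.
        """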
        process_internal_dataset_chunk = list()
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # skip torch.hub's forked-repo validation, which can fail in some torch versions

        silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                             model='silero_vad',
                                             force_reload=False,
                                             onnx=False,
                                             verbose=False)
        (get_speech_timestamps,
         save_audio,
         read_audio,
         VADIterator,
         collect_chunks) = utils
        torch.set_grad_enabled(True)  # loading the VAD model can leave gradients globally disabled, so they are re-enabled explicitly

        silero_model = silero_model.to(device)
        silence = torch.zeros([16000 // 8]).to(device)  # an eighth of a second of silence at 16kHz, used as padding around every utterance
        tf = ArticulatoryCombinedTextFrontend(language=lang, device=device)
        _, sr = sf.read(path_list[0])
        assumed_sr = sr  # the sampling rate of the first file is assumed for all files, until a file with a different rate shows up
        ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
        resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)

        for path in tqdm(path_list):
            if self.path_to_transcript_dict[path].strip() == "":
                continue
            try:
                wave, sr = sf.read(path)
            except:
                print(f"Problem with an audio file: {path}")
                continue

            if len(wave.shape) > 1:
                # multi-channel audio needs to be merged to mono; librosa.to_mono expects (channels, frames), while soundfile returns (frames, channels)
                if len(wave[0]) == 2:
                    wave = wave.transpose()
                wave = librosa.to_mono(wave)

            if sr != assumed_sr:
                assumed_sr = sr
                ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
                resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
                print(f"{path} has a different sampling rate --> adapting the codec processor")

            try:
                norm_wave = resample(torch.tensor(wave).float().to(device))
            except ValueError:
                continue
            dur_in_seconds = len(norm_wave) / 16000
            if not (min_len <= dur_in_seconds <= max_len):
                if verbose:
                    print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
                continue
            with torch.inference_mode():
                speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000)
            try:
                # zero out everything the VAD considers non-speech, then cut away the non-speech at the beginning and the end
                silence_timestamps = invert_segments(speech_timestamps, len(norm_wave))
                for silence_timestamp in silence_timestamps:
                    begin = silence_timestamp['start']
                    end = silence_timestamp['end']
                    norm_wave = torch.cat([norm_wave[:begin], torch.zeros([end - begin], device=device), norm_wave[end:]])
                result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
            except IndexError:
                print("Audio might be too short to cut silences from front and back.")
                continue
            norm_wave = torch.cat([silence, result, silence])  # pad the trimmed utterance with a short, clean silence on both sides

            # raw audio preprocessing is complete, now the transcript is converted to articulatory features
            transcript = self.path_to_transcript_dict[path]

            try:
                try:
                    cached_text = tf.string_to_tensor(transcript, handle_missing=False, input_phonemes=phone_input).squeeze(0).cpu().numpy()
                except KeyError:
                    cached_text = tf.string_to_tensor(transcript, handle_missing=True, input_phonemes=phone_input).squeeze(0).cpu().numpy()
                    if not allow_unknown_symbols:
                        continue  # skip samples that contain unknown symbols
            except ValueError:
                # the text frontend could not convert this transcript, so the sample is skipped
                continue
            except KeyError:
                # the text frontend could not convert this transcript, so the sample is skipped
                continue

            cached_speech = ap.audio_to_codebook_indexes(audio=norm_wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy()
            process_internal_dataset_chunk.append([cached_text,
                                                   cached_speech,
                                                   norm_wave.cpu().detach().numpy(),
                                                   path])
        self.result_pool.append(process_internal_dataset_chunk)

    def __getitem__(self, index):
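        """
        Returns the token id sequence, its length, the codec codebook indexes (with the
        codebook dimension first), a None placeholder, and the speaker embedding.
        """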
        text_vector = self.datapoints[index][0]
        tokens = self.tf.text_vectors_to_id_sequence(text_vector=text_vector)
        tokens = torch.LongTensor(tokens)
        token_len = torch.LongTensor([len(tokens)])

        codes = self.datapoints[index][1]
        if codes.size()[0] != 24:
            codes = codes.transpose(0, 1)  # make sure the codebook dimension comes first

        return tokens, \
               token_len, \
               codes, \
               None, \
               self.speaker_embeddings[index]

    def __len__(self):
        return len(self.datapoints)

    def remove_samples(self, list_of_samples_to_remove):
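        """
        Removes the given datapoint indices and overwrites the cache on disk. The indices
        are removed in descending order, so earlier removals don't shift the later ones.
        """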
        for remove_id in sorted(list_of_samples_to_remove, reverse=True):
            self.datapoints.pop(remove_id)
            self.speaker_embeddings.pop(remove_id)
            self.filepaths.pop(remove_id)
        torch.save((self.datapoints, None, self.speaker_embeddings, self.filepaths),
                   os.path.join(self.cache_dir, "aligner_train_cache.pt"))
        print("Dataset updated!")


def fisher_yates_shuffle(lst):
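    """
    Shuffles a list in place using the Fisher-Yates algorithm.
    """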
    for i in range(len(lst) - 1, 0, -1):
        j = random.randint(0, i)
        lst[i], lst[j] = lst[j], lst[i]


def invert_segments(segments, total_duration):
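    """
    Takes speech segments as {'start': ..., 'end': ...} dicts over a signal with
    total_duration samples and returns the complementary (non-speech) segments.
    """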
    if not segments:
        return [{'start': 0, 'end': total_duration}]

    inverted_segments = []
    previous_end = 0

    for segment in segments:
        start = segment['start']
        if previous_end < start:
            inverted_segments.append({'start': previous_end, 'end': start})
        previous_end = segment['end']

    if previous_end < total_duration:
        inverted_segments.append({'start': previous_end, 'end': total_duration})

    return inverted_segments
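

if __name__ == "__main__":
    # Minimal usage sketch with hypothetical paths and settings: the transcript dict maps
    # audio file paths to their transcripts, the cache directory is created if it does not
    # exist yet, and the exact language code depends on what the ArticulatoryCombinedTextFrontend accepts.
    example_transcripts = {"/data/corpus/sample_0001.wav": "Hello world.",
                           "/data/corpus/sample_0002.wav": "This is a second utterance."}
    dataset = CodecAlignerDataset(path_to_transcript_dict=example_transcripts,
                                  cache_dir="Corpora/example_aligner_cache",
                                  lang="eng",
                                  loading_processes=4,
                                  device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"Number of datapoints: {len(dataset)}")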