# File: Comparative-Analysis-of-Speech-Synthesis-Models/TensorFlowTTS/tensorflow_tts/processor/synpaflex.py
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perform preprocessing and raw feature extraction for SynPaFlex dataset."""
import os
import re
from dataclasses import dataclass

import numpy as np
import soundfile as sf

from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
# Special tokens: padding and end-of-sequence markers.
_pad = "pad"
_eos = "eos"
# Single-character symbols; each character becomes its own vocabulary entry.
_punctuation = "!/\'(),-.:;? "
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzéèàùâêîôûçäëïöüÿœæ"

# Export all symbols: pad first, then punctuation, then letters, eos last.
SYNPAFLEX_SYMBOLS = (
    [_pad] + list(_punctuation) + list(_letters) + [_eos]
)

# Regular expression matching text enclosed in curly braces:
# group 1 = text before the braces, group 2 = brace contents, group 3 = rest.
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
class SynpaflexProcessor(BaseProcessor):
    """SynPaFlex processor.

    Reads a ``|``-delimited metadata file from ``data_dir``, resolves every
    entry to a wav file under ``<data_dir>/wavs`` and converts the associated
    text into a sequence of symbol ids terminated by the eos id.
    """

    # NOTE(review): `dataclass` is imported by this module but no decorator is
    # applied to this class -- confirm whether `@dataclass` is intended here.
    cleaner_names: str = "basic_cleaners"
    # Column index of each field within one split metadata line.
    positions = {
        "wave_file": 0,
        "text": 1,
        "text_norm": 2,
    }
    train_f_name: str = "synpaflex.txt"

    def create_items(self):
        """Fill ``self.items`` from the metadata file, if ``data_dir`` is set.

        Each non-blank line is parsed by :meth:`split_line`. Blank lines (for
        example a trailing empty line) are skipped because they carry no
        fields and would otherwise fail during parsing.
        """
        if not self.data_dir:
            return
        metadata_path = os.path.join(self.data_dir, self.train_f_name)
        with open(metadata_path, encoding="utf-8") as f:
            self.items = [
                self.split_line(self.data_dir, line, "|")
                for line in f
                if line.strip()
            ]

    def split_line(self, data_dir, line, split):
        """Parse one metadata line into ``(text, wav_path, speaker_name)``.

        Args:
            data_dir: Dataset root; wav files are looked up in its ``wavs``
                sub-directory.
            line: Raw metadata line.
            split: Field delimiter.

        Returns:
            A tuple ``(text, wav_path, speaker_name)`` where ``speaker_name``
            is always ``"synpaflex"``.
        """
        parts = line.strip().split(split)
        wave_file = parts[self.positions["wave_file"]]
        text = parts[self.positions["text"]]
        wav_path = os.path.join(data_dir, "wavs", f"{wave_file}.wav")
        speaker_name = "synpaflex"
        return text, wav_path, speaker_name

    def setup_eos_token(self):
        """Return the end-of-sequence symbol used by this processor."""
        return _eos

    def get_one_sample(self, item):
        """Load the audio and encode the text for a single item.

        Args:
            item: A ``(text, wav_path, speaker_name)`` tuple as produced by
                :meth:`split_line`.

        Returns:
            A dict with keys ``raw_text``, ``text_ids``, ``audio``,
            ``utt_id``, ``speaker_name`` and ``rate``.
        """
        text, wav_path, speaker_name = item

        # No extra scaling is applied: the audio returned by soundfile is
        # taken to be already normalized to [-1, 1]; only the dtype changes.
        audio, rate = sf.read(wav_path)
        audio = audio.astype(np.float32)

        # convert text to ids
        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        sample = {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            # File name up to its first "." (directory part removed).
            "utt_id": os.path.split(wav_path)[-1].split(".")[0],
            "speaker_name": speaker_name,
            "rate": rate,
        }
        return sample

    def text_to_sequence(self, text):
        """Convert ``text`` into a list of symbol ids ending with the eos id.

        Text outside curly braces is cleaned with ``self.cleaner_names`` and
        mapped character by character. Text inside ``{...}`` is treated as
        whitespace-separated ARPAbet tokens, each looked up with an ``@``
        prefix. Symbols missing from ``self.symbol_to_id`` are dropped.
        """
        sequence = []
        # Check for curly braces and treat their contents as ARPAbet:
        while text:
            m = _curly_re.match(text)
            if not m:
                sequence += self._symbols_to_sequence(
                    self._clean_text(text, [self.cleaner_names])
                )
                break
            sequence += self._symbols_to_sequence(
                self._clean_text(m.group(1), [self.cleaner_names])
            )
            sequence += self._arpabet_to_sequence(m.group(2))
            # Continue with whatever follows the closing brace.
            text = m.group(3)

        # add eos tokens
        sequence += [self.eos_id]
        return sequence

    def _clean_text(self, text, cleaner_names):
        """Apply each named cleaner from ``cleaners`` to ``text`` in order.

        Raises:
            Exception: If a name does not refer to an attribute of the
                ``cleaners`` module.
        """
        for name in cleaner_names:
            # Default of None so an unknown name reaches the explicit error
            # below instead of surfacing as a bare AttributeError.
            cleaner = getattr(cleaners, name, None)
            if cleaner is None:
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text)
        return text

    def _symbols_to_sequence(self, symbols):
        """Map symbols to ids, skipping any that should not be kept."""
        return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]

    def _sequence_to_symbols(self, sequence):
        """Map a sequence of ids back to their symbols."""
        return [self.id_to_symbol[s] for s in sequence]

    def _arpabet_to_sequence(self, text):
        """Map whitespace-separated ARPAbet tokens (``@``-prefixed) to ids."""
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def _should_keep_symbol(self, s):
        """Return True for known symbols other than ``_`` and ``~``."""
        return s in self.symbol_to_id and s != "_" and s != "~"

    def save_pretrained(self, saved_path):
        """Create ``saved_path`` if needed and save the symbol mapper there.

        The mapper is written to ``PROCESSOR_FILE_NAME`` inside ``saved_path``
        via ``self._save_mapper`` with an empty extra-attributes dict.
        """
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})