# Ricecake123's picture
# first commit
# e79b770
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/phonemizer.py
import itertools
import re
from typing import Dict
from typing import List
import regex
from gruut import sentences
from gruut.const import Sentence
from gruut.const import Word
from AR.text_processing.symbols import SYMBOL_TO_ID
class GruutPhonemizer:
    """Turn raw text into a phoneme string with gruut, and phonemes into ids.

    Typical use is ``phonemize(text)`` followed by ``transform(phonemes)``;
    the symbol table used by ``transform`` is ``SYMBOL_TO_ID``.
    """

    def __init__(self, language: str):
        """Create a phonemizer.

        Args:
            language: Language code passed to gruut as ``lang`` on every
                ``phonemize`` call (for example ``"en-us"``).
        """
        self._phonemizer = sentences
        self.lang = language
        self.symbol_to_id = SYMBOL_TO_ID
        # Keys are regex fragments for the punctuation marks this class
        # treats specially. Only the keys are read in this class (to build
        # the character class below); the values are not used here.
        self._special_cases_dict: Dict[str, str] = {
            r"\.\.\.": "... ",
            ";": "; ",
            ":": ": ",
            ",": ", ",
            r"\.": ". ",
            "!": "! ",
            r"\?": "? ",
            "—": "—",
            "…": "… ",
            "«": "«",
            "»": "»",
        }
        # The keys are concatenated inside [...], so the result matches any
        # ONE of the punctuation characters (captured as group 1).
        punctuation_chars = "".join(self._special_cases_dict)
        self._punctuation_regexp: str = rf"([{punctuation_chars}])"

    def _normalize_punctuation(self, text: str) -> str:
        """Tidy the spacing around punctuation before phonemization.

        Three passes, in order:
          1. drop separator characters (``\\pZ``) that directly precede a
             punctuation mark;
          2. insert one space between a punctuation mark and a letter
             (``\\pL``) that directly follows it;
          3. collapse every run of separator characters to a single space.

        Args:
            text: Raw input text.

        Returns:
            The normalized text with leading/trailing whitespace stripped.
        """
        text = regex.sub(fr"\pZ+{self._punctuation_regexp}", r"\1", text)
        text = regex.sub(fr"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
        text = regex.sub(r"\pZ+", r" ", text)
        return text.strip()

    def _convert_punctuation(self, word: "Word") -> str:
        """Render one gruut word as a plain string.

        Args:
            word: Object exposing ``phonemes`` (sequence of str) and ``text``.

        Returns:
            ``''`` when the word has no phonemes; the word's stripped text
            when its first phoneme is a break marker (``'‖'`` or ``'|'``);
            otherwise the concatenated phonemes with modifier marks removed.
        """
        if not word.phonemes:
            return ''
        # NOTE(review): break markers presumably stand for punctuation
        # tokens in gruut's output; the original text is kept for those.
        if word.phonemes[0] in ['‖', '|']:
            return word.text.strip()

        phonemes = ''.join(word.phonemes)
        # Remove stress marks (ˈ ˌ), the length mark (ː) and the combining
        # tie bar (U+0361).
        phonemes = re.sub(r'[ˈˌː͡]', '', phonemes)
        return phonemes.strip()

    def phonemize(self, text: str, espeak: bool = False) -> str:
        """Convert text to a space-separated phoneme string.

        Args:
            text: Raw input text; its punctuation spacing is normalized first.
            espeak: Forwarded unchanged to gruut's ``sentences`` call.

        Returns:
            One entry per gruut word, joined with single spaces. A word
            without phonemes contributes an empty entry, so consecutive
            spaces can appear in the result.
        """
        text_to_phonemize: str = self._normalize_punctuation(text)
        # Use the language chosen at construction time rather than a fixed
        # code, so the ``language`` argument of __init__ takes effect.
        sents: List[Sentence] = list(
            self._phonemizer(text_to_phonemize, lang=self.lang, espeak=espeak))
        words: List[str] = [
            self._convert_punctuation(word)
            for word in itertools.chain.from_iterable(sents)
        ]
        return ' '.join(words)

    def transform(self, phonemes):
        """Convert phonemes to ids using ``self.symbol_to_id``.

        Args:
            phonemes: Iterable of symbols; iterating a ``str`` yields one
                character at a time.

        Returns:
            List of ids in input order. Symbols that are not keys of
            ``self.symbol_to_id`` are skipped silently.
        """
        # dictionary is in symbols.py
        symbol_to_id = self.symbol_to_id
        return [symbol_to_id[p] for p in phonemes if p in symbol_to_id]
if __name__ == "__main__":
    # Small demo: push one sentence through text -> IPA -> symbol ids and
    # print each stage together with its length.
    demo = GruutPhonemizer("en-us")

    # text -> IPA
    ipa = demo.phonemize("Hello, wor-ld ?")
    print("phonemes:", ipa)
    print("len(phonemes):", len(ipa))

    # IPA -> ids
    ids = demo.transform(ipa)
    print("phoneme_ids:", ids)
    print("len(phoneme_ids):", len(ids))