|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
from dataclasses import asdict, dataclass |
|
from typing import Any, Dict, List, Optional, Pattern, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torchaudio |
|
from encodec import EncodecModel |
|
from encodec.utils import convert_audio |
|
from phonemizer.backend import EspeakBackend |
|
from phonemizer.backend.espeak.language_switch import LanguageSwitch |
|
from phonemizer.backend.espeak.words_mismatch import WordMismatch |
|
from phonemizer.punctuation import Punctuation |
|
from phonemizer.separator import Separator |
|
|
|
|
try: |
|
from pypinyin import Style, pinyin |
|
from pypinyin.style._utils import get_finals, get_initials |
|
except Exception: |
|
pass |
|
|
|
|
|
class PypinyinBackend: |
|
"""PypinyinBackend for Chinese. Most codes is referenced from espnet. |
|
There are two types pinyin or initials_finals, one is |
|
just like "ni1 hao3", the other is like "n i1 h ao3". |
|
""" |
|
|
|
def __init__( |
|
self, |
|
backend="initials_finals", |
|
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), |
|
) -> None: |
|
self.backend = backend |
|
self.punctuation_marks = punctuation_marks |
|
|
|
def phonemize( |
|
self, text: List[str], separator: Separator, strip=True, njobs=1 |
|
) -> List[str]: |
|
        assert isinstance(text, list)
|
phonemized = [] |
|
for _text in text: |
|
_text = re.sub(" +", " ", _text.strip()) |
|
_text = _text.replace(" ", separator.word) |
|
phones = [] |
|
if self.backend == "pypinyin": |
|
                for py in pinyin(
                    _text, style=Style.TONE3, neutral_tone_with_five=True
                ):
|
if all([c in self.punctuation_marks for c in py[0]]): |
|
if len(phones): |
|
assert phones[-1] == separator.syllable |
|
phones.pop(-1) |
|
|
|
phones.extend(list(py[0])) |
|
else: |
|
phones.extend([py[0], separator.syllable]) |
|
elif self.backend == "pypinyin_initials_finals": |
|
                for py in pinyin(
                    _text, style=Style.TONE3, neutral_tone_with_five=True
                ):
|
if all([c in self.punctuation_marks for c in py[0]]): |
|
if len(phones): |
|
assert phones[-1] == separator.syllable |
|
phones.pop(-1) |
|
phones.extend(list(py[0])) |
|
else: |
|
if py[0][-1].isalnum(): |
|
initial = get_initials(py[0], strict=False) |
|
if py[0][-1].isdigit(): |
|
final = ( |
|
get_finals(py[0][:-1], strict=False) |
|
+ py[0][-1] |
|
) |
|
else: |
|
final = get_finals(py[0], strict=False) |
|
phones.extend( |
|
[ |
|
initial, |
|
separator.phone, |
|
final, |
|
separator.syllable, |
|
] |
|
) |
|
else: |
|
                            raise ValueError(f"Unexpected pinyin token: {py[0]}")
|
else: |
|
raise NotImplementedError |
|
phonemized.append( |
|
"".join(phones).rstrip(f"{separator.word}{separator.syllable}") |
|
) |
|
return phonemized |
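
# Example (a rough sketch; assumes pypinyin is installed, and note that the
# word separator has to appear in punctuation_marks, which is how
# TextTokenizer below configures this backend):
#
#   sep = Separator(word="_", syllable="-", phone="|")
#   backend = PypinyinBackend(
#       backend="pypinyin_initials_finals",
#       punctuation_marks=Punctuation.default_marks() + "_",
#   )
#   backend.phonemize(["你好 世界"], separator=sep)
#   # -> roughly ["n|i3-h|ao3_sh|i4-j|ie4"]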
|
|
|
|
|
class TextTokenizer: |
|
"""Phonemize Text.""" |
|
|
|
def __init__( |
|
self, |
|
language="en-us", |
|
backend="espeak", |
|
separator=Separator(word="_", syllable="-", phone="|"), |
|
preserve_punctuation=True, |
|
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), |
|
with_stress: bool = False, |
|
tie: Union[bool, str] = False, |
|
language_switch: LanguageSwitch = "keep-flags", |
|
words_mismatch: WordMismatch = "ignore", |
|
) -> None: |
|
if backend == "espeak": |
|
phonemizer = EspeakBackend( |
|
language, |
|
punctuation_marks=punctuation_marks, |
|
preserve_punctuation=preserve_punctuation, |
|
with_stress=with_stress, |
|
tie=tie, |
|
language_switch=language_switch, |
|
words_mismatch=words_mismatch, |
|
) |
|
elif backend in ["pypinyin", "pypinyin_initials_finals"]: |
|
phonemizer = PypinyinBackend( |
|
backend=backend, |
|
punctuation_marks=punctuation_marks + separator.word, |
|
) |
|
else: |
|
raise NotImplementedError(f"{backend}") |
|
|
|
self.backend = phonemizer |
|
self.separator = separator |
|
|
|
def to_list(self, phonemized: str) -> List[str]: |
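        """Split a phonemized string into a flat list of symbols, turning word
        boundaries into explicit ``separator.word`` tokens and dropping the
        phone separator."""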
|
fields = [] |
|
for word in phonemized.split(self.separator.word): |
|
|
|
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) |
|
fields.extend( |
|
[p for p in pp if p != self.separator.phone] |
|
+ [self.separator.word] |
|
) |
|
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( |
|
self.separator.phone |
|
) |
|
return fields[:-1] |
|
|
|
def __call__(self, text, strip=True) -> List[List[str]]: |
|
if isinstance(text, str): |
|
text = [text] |
|
|
|
phonemized = self.backend.phonemize( |
|
text, separator=self.separator, strip=strip, njobs=1 |
|
) |
|
return [self.to_list(p) for p in phonemized] |
|
|
|
|
|
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: |
|
phonemes = tokenizer([text.strip()]) |
|
return phonemes[0] |
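
# Example usage (a rough sketch; the exact phone symbols depend on the
# installed espeak / pypinyin versions, and the inputs are only illustrative):
#
#   en_tokenizer = TextTokenizer(backend="espeak", language="en-us")
#   tokenize_text(en_tokenizer, "hello world")
#   # -> something like ['h', 'ə', 'l', 'oʊ', '_', 'w', 'ɜː', 'l', 'd']
#
#   zh_tokenizer = TextTokenizer(backend="pypinyin_initials_finals")
#   tokenize_text(zh_tokenizer, "你好 世界")
#   # -> roughly ['n', 'i3', '-', 'h', 'ao3', '_', 'sh', 'i4', '-', 'j', 'ie4']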
|
|
|
|
|
def remove_encodec_weight_norm(model): |
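    """Fold EnCodec's weight-norm parametrization back into plain weights.

    torch.nn.utils.remove_weight_norm recomputes weight = weight_g * weight_v
    / ||weight_v|| and deletes the extra parameters, so the model's outputs
    are unchanged (see the check under __main__ below).
    """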
|
from encodec.modules import SConv1d |
|
from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock |
|
from torch.nn.utils import remove_weight_norm |
|
|
|
encoder = model.encoder.model |
|
for key in encoder._modules: |
|
if isinstance(encoder._modules[key], SEANetResnetBlock): |
|
remove_weight_norm(encoder._modules[key].shortcut.conv.conv) |
|
block_modules = encoder._modules[key].block._modules |
|
for skey in block_modules: |
|
if isinstance(block_modules[skey], SConv1d): |
|
remove_weight_norm(block_modules[skey].conv.conv) |
|
elif isinstance(encoder._modules[key], SConv1d): |
|
remove_weight_norm(encoder._modules[key].conv.conv) |
|
|
|
decoder = model.decoder.model |
|
for key in decoder._modules: |
|
if isinstance(decoder._modules[key], SEANetResnetBlock): |
|
remove_weight_norm(decoder._modules[key].shortcut.conv.conv) |
|
block_modules = decoder._modules[key].block._modules |
|
for skey in block_modules: |
|
if isinstance(block_modules[skey], SConv1d): |
|
remove_weight_norm(block_modules[skey].conv.conv) |
|
elif isinstance(decoder._modules[key], SConvTranspose1d): |
|
remove_weight_norm(decoder._modules[key].convtr.convtr) |
|
elif isinstance(decoder._modules[key], SConv1d): |
|
remove_weight_norm(decoder._modules[key].conv.conv) |
|
|
|
|
|
class AudioTokenizer: |
|
"""EnCodec audio.""" |
|
|
|
def __init__( |
|
self, |
|
device: Any = None, |
|
) -> None: |
|
|
|
model = EncodecModel.encodec_model_24khz() |
|
model.set_target_bandwidth(6.0) |
|
remove_encodec_weight_norm(model) |
|
|
|
if not device: |
|
device = torch.device("cpu") |
|
if torch.cuda.is_available(): |
|
device = torch.device("cuda:0") |
|
|
|
self._device = device |
|
|
|
self.codec = model.to(device) |
|
self.sample_rate = model.sample_rate |
|
self.channels = model.channels |
|
|
|
@property |
|
def device(self): |
|
return self._device |
|
|
|
def encode(self, wav: torch.Tensor) -> torch.Tensor: |
|
return self.codec.encode(wav.to(self.device)) |
|
|
|
def decode(self, frames: torch.Tensor) -> torch.Tensor: |
|
return self.codec.decode(frames) |
|
|
|
|
|
def tokenize_audio(tokenizer: AudioTokenizer, audio): |
|
|
|
if isinstance(audio, str): |
|
wav, sr = torchaudio.load(audio) |
|
else: |
|
wav, sr = audio |
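    # Resample/remix to the codec's expected sample rate and channel count,
    # then add a batch dimension before encoding.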
|
wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) |
|
wav = wav.unsqueeze(0) |
|
|
|
|
|
with torch.no_grad(): |
|
encoded_frames = tokenizer.encode(wav) |
|
return encoded_frames |
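
# Example usage (a rough sketch; "prompt.wav" is a hypothetical input file):
#
#   audio_tokenizer = AudioTokenizer()
#   encoded_frames = tokenize_audio(audio_tokenizer, "prompt.wav")
#   # EncodecModel.encode returns a list of (codes, scale) tuples; at 6 kbps
#   # the 24 kHz model uses 8 codebooks, so codes has shape (1, 8, T).
#   codes = encoded_frames[0][0]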
|
if __name__ == "__main__": |
|
model = EncodecModel.encodec_model_24khz() |
|
model.set_target_bandwidth(6.0) |
|
|
|
samples = torch.from_numpy(np.random.random([4, 1, 1600])).type( |
|
torch.float32 |
|
) |
|
codes_raw = model.encode(samples) |
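
    # Removing weight norm should not change the model's outputs; verify that
    # the codes from the two encoding passes match exactly.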
|
|
|
remove_encodec_weight_norm(model) |
|
codes_norm = model.encode(samples) |
|
|
|
assert torch.allclose(codes_raw[0][0], codes_norm[0][0]) |
|
|