VALL-E-X / data /tokenizer.py
Plachta's picture
initial commit
b1e1a76
raw
history blame
13.9 kB
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union
import numpy as np
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
from phonemizer.separator import Separator
try:
from pypinyin import Style, pinyin
from pypinyin.style._utils import get_finals, get_initials
except Exception:
pass
class PypinyinBackend:
    """Pypinyin-based phonemizer backend for Chinese text.

    Most of the code is adapted from espnet. Two output types are supported:

    * ``"pypinyin"``: whole syllables, e.g. "ni1 hao3".
    * ``"pypinyin_initials_finals"`` (also accepted under the constructor's
      default name ``"initials_finals"``): each syllable is split into its
      initial and final, e.g. "n i1 h ao3".
    """

    # Backend names that select the initial/final split. "initials_finals" is
    # the constructor default and is treated as an alias so that a
    # default-constructed backend is usable.
    _INITIALS_FINALS_BACKENDS = ("pypinyin_initials_finals", "initials_finals")

    def __init__(
        self,
        backend="initials_finals",
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
    ) -> None:
        """Store the output type and the set of punctuation marks.

        Args:
            backend: Output type; see the class docstring for accepted names.
            punctuation_marks: Characters treated as punctuation.
                NOTE(review): ``phonemize`` tests membership with ``in``, which
                assumes a string of marks; a compiled ``Pattern`` would not
                support that -- confirm against callers.
        """
        self.backend = backend
        self.punctuation_marks = punctuation_marks

    def phonemize(
        self, text: List[str], separator: Separator, strip=True, njobs=1
    ) -> List[str]:
        """Convert each input string into a pinyin phone string.

        Args:
            text: List of sentences to phonemize.
            separator: Provides the ``word``, ``syllable`` and ``phone``
                separator strings inserted into the output.
            strip: Not used by this backend; callers in this file pass it to
                every backend with the same keyword.
            njobs: Not used by this backend; accepted for the same reason.

        Returns:
            One string per input sentence, with trailing word/syllable
            separator characters removed.

        Raises:
            NotImplementedError: If ``self.backend`` is not a supported name.
            ValueError: In initials/finals mode, if a non-punctuation token
                does not end with an alphanumeric character.
        """
        assert isinstance(text, List)
        phonemized = []
        for _text in text:
            # Collapse runs of spaces, then mark word boundaries explicitly.
            _text = re.sub(" +", " ", _text.strip())
            _text = _text.replace(" ", separator.word)

            if self.backend == "pypinyin":
                split_syllables = False
            elif self.backend in self._INITIALS_FINALS_BACKENDS:
                split_syllables = True
            else:
                raise NotImplementedError(f"{self.backend}")

            phones = []
            for py in pinyin(
                _text, style=Style.TONE3, neutral_tone_with_five=True
            ):
                token = py[0]
                if all(c in self.punctuation_marks for c in token):
                    # Punctuation attaches directly to the previous syllable:
                    # drop the syllable separator emitted after it.
                    if phones:
                        assert phones[-1] == separator.syllable
                        phones.pop()
                    phones.extend(token)
                elif not split_syllables:
                    phones.extend([token, separator.syllable])
                elif token[-1].isalnum():
                    initial = get_initials(token, strict=False)
                    if token[-1].isdigit():
                        # Keep the trailing tone digit attached to the final.
                        final = get_finals(token[:-1], strict=False) + token[-1]
                    else:
                        final = get_finals(token, strict=False)
                    phones.extend(
                        [initial, separator.phone, final, separator.syllable]
                    )
                else:
                    # The original code used `assert ValueError`, which never
                    # fails and silently dropped the token from the output.
                    raise ValueError(
                        f"Cannot split {token!r} into initial and final: "
                        "it is neither punctuation nor alphanumeric-terminated"
                    )

            phonemized.append(
                "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
            )
        return phonemized
class TextTokenizer:
    """Turn raw text into lists of phoneme symbols.

    Wraps either an ``EspeakBackend`` or a ``PypinyinBackend`` (chosen by the
    ``backend`` argument) and splits the phonemized string it produces into
    individual symbols.
    """

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        """Build the underlying phonemizer.

        Args:
            language: Language code handed to ``EspeakBackend``.
            backend: ``"espeak"``, ``"pypinyin"`` or
                ``"pypinyin_initials_finals"``.
            separator: Word/syllable/phone separators used when phonemizing
                and when splitting the result in ``to_list``.
            preserve_punctuation, with_stress, tie, language_switch,
                words_mismatch: Forwarded to ``EspeakBackend`` only.
            punctuation_marks: Forwarded to the chosen backend; for the
                pypinyin backends the word separator is appended to it.

        Raises:
            NotImplementedError: If ``backend`` is not one of the names above.
        """
        if backend == "espeak":
            espeak_options = dict(
                punctuation_marks=punctuation_marks,
                preserve_punctuation=preserve_punctuation,
                with_stress=with_stress,
                tie=tie,
                language_switch=language_switch,
                words_mismatch=words_mismatch,
            )
            engine = EspeakBackend(language, **espeak_options)
        elif backend in ("pypinyin", "pypinyin_initials_finals"):
            # The word separator is treated as punctuation by this backend.
            marks_with_word_sep = punctuation_marks + separator.word
            engine = PypinyinBackend(
                backend=backend,
                punctuation_marks=marks_with_word_sep,
            )
        else:
            raise NotImplementedError(f"{backend}")

        self.backend = engine
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        """Split one phonemized string into a flat list of symbols.

        Words are separated by ``separator.word``, which is kept as its own
        symbol between words; ``separator.phone`` symbols are dropped.
        """
        word_sep = self.separator.word
        phone_sep = self.separator.phone

        symbols = []
        for word in phonemized.split(word_sep):
            # "ɐ    m|iː|n?"    ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            pieces = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            symbols.extend(piece for piece in pieces if piece != phone_sep)
            symbols.append(word_sep)
        # A word separator was appended after every word; drop the last one.
        symbols.pop()

        # Sanity check: apart from phone separators, no character was lost.
        expected_length = len(phonemized) - phonemized.count(phone_sep)
        assert len("".join(symbols)) == expected_length
        return symbols

    def __call__(self, text, strip=True) -> List[List[str]]:
        """Phonemize a string or a list of strings.

        Returns one symbol list per input sentence; a single string is
        handled as a one-element batch.
        """
        batch = [text] if isinstance(text, str) else text
        outputs = self.backend.phonemize(
            batch, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(output) for output in outputs]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    """Phonemize a single string and return its list of symbols.

    The text is stripped of surrounding whitespace, passed to ``tokenizer``
    as a one-element batch, and the first (only) result is returned.
    """
    cleaned = text.strip()
    batch = tokenizer([cleaned])
    return batch[0]  # k2symbols
def remove_encodec_weight_norm(model):
    """Remove weight normalization from an EnCodec model's convolutions.

    Walks the direct children of ``model.encoder.model`` and
    ``model.decoder.model`` and calls ``torch.nn.utils.remove_weight_norm``
    on the wrapped convolution of every ``SConv1d``, on the shortcut and
    inner ``SConv1d`` layers of every ``SEANetResnetBlock``, and (decoder
    only) on every ``SConvTranspose1d``. The model is modified in place.
    """
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    def _strip_resnet_block(resnet_block):
        # Shortcut convolution first, then every SConv1d inside the block.
        remove_weight_norm(resnet_block.shortcut.conv.conv)
        for inner in resnet_block.block._modules.values():
            if isinstance(inner, SConv1d):
                remove_weight_norm(inner.conv.conv)

    for layer in model.encoder.model._modules.values():
        if isinstance(layer, SEANetResnetBlock):
            _strip_resnet_block(layer)
        elif isinstance(layer, SConv1d):
            remove_weight_norm(layer.conv.conv)

    for layer in model.decoder.model._modules.values():
        if isinstance(layer, SEANetResnetBlock):
            _strip_resnet_block(layer)
        elif isinstance(layer, SConvTranspose1d):
            remove_weight_norm(layer.convtr.convtr)
        elif isinstance(layer, SConv1d):
            remove_weight_norm(layer.conv.conv)
class AudioTokenizer:
    """Thin wrapper around a pretrained 24 kHz EnCodec model.

    On construction the model is loaded, its target bandwidth is set to 6.0,
    weight normalization is removed, and the model is moved to the chosen
    device. ``sample_rate`` and ``channels`` mirror the model's attributes.
    """

    def __init__(
        self,
        device: Any = None,
    ) -> None:
        """Load the codec and place it on ``device``.

        If ``device`` is falsy, ``cuda:0`` is used when CUDA is available and
        the CPU otherwise.
        """
        # Instantiate a pretrained EnCodec model
        codec = EncodecModel.encodec_model_24khz()
        codec.set_target_bandwidth(6.0)
        remove_encodec_weight_norm(codec)

        if not device:
            if torch.cuda.is_available():
                device = torch.device("cuda:0")
            else:
                device = torch.device("cpu")

        self._device = device
        self.codec = codec.to(device)
        self.sample_rate = codec.sample_rate
        self.channels = codec.channels

    @property
    def device(self):
        """Device the codec was moved to at construction time."""
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        """Move ``wav`` to the tokenizer's device and encode it with the codec."""
        wav_on_device = wav.to(self.device)
        return self.codec.encode(wav_on_device)

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        """Decode ``frames`` with the codec; the input is passed through as is."""
        return self.codec.decode(frames)
def tokenize_audio(tokenizer: AudioTokenizer, audio):
    """Encode an audio clip with the tokenizer's codec.

    Args:
        tokenizer: Provides ``sample_rate``, ``channels`` and ``encode``.
        audio: Either a string, which is loaded with ``torchaudio.load``, or
            a ``(waveform, sample_rate)`` pair.

    Returns:
        Whatever ``tokenizer.encode`` returns for the converted waveform.
    """
    # Load and pre-process the audio waveform
    loaded = torchaudio.load(audio) if isinstance(audio, str) else audio
    waveform, source_rate = loaded
    waveform = convert_audio(
        waveform, source_rate, tokenizer.sample_rate, tokenizer.channels
    )
    # Add a leading dimension before handing the waveform to the codec.
    batched = waveform.unsqueeze(0)

    # Extract discrete codes from EnCodec
    with torch.no_grad():
        return tokenizer.encode(batched)
# @dataclass
# class AudioTokenConfig:
# frame_shift: Seconds = 320.0 / 24000
# num_quantizers: int = 8
#
# def to_dict(self) -> Dict[str, Any]:
# return asdict(self)
#
# @staticmethod
# def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
# return AudioTokenConfig(**data)
#
#
# class AudioTokenExtractor(FeatureExtractor):
# name = "encodec"
# config_type = AudioTokenConfig
#
# def __init__(self, config: Optional[Any] = None):
# super(AudioTokenExtractor, self).__init__(config)
# self.tokenizer = AudioTokenizer()
#
# def extract(
# self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
# ) -> np.ndarray:
# if not isinstance(samples, torch.Tensor):
# samples = torch.from_numpy(samples)
# if sampling_rate != self.tokenizer.sample_rate:
# samples = convert_audio(
# samples,
# sampling_rate,
# self.tokenizer.sample_rate,
# self.tokenizer.channels,
# )
# if len(samples.shape) == 2:
# samples = samples.unsqueeze(0)
# else:
# raise ValueError()
#
# device = self.tokenizer.device
# encoded_frames = self.tokenizer.encode(samples.detach().to(device))
# codes = encoded_frames[0][0] # [B, n_q, T]
# if True:
# duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
# expected_num_frames = compute_num_frames(
# duration=duration,
# frame_shift=self.frame_shift,
# sampling_rate=sampling_rate,
# )
# assert abs(codes.shape[-1] - expected_num_frames) <= 1
# codes = codes[..., :expected_num_frames]
# return codes.cpu().squeeze(0).permute(1, 0).numpy()
#
# @property
# def frame_shift(self) -> Seconds:
# return self.config.frame_shift
#
# def feature_dim(self, sampling_rate: int) -> int:
# return self.config.num_quantizers
#
# def pad_tensor_list(self, tensor_list, device, padding_value=0):
# # Compute the length of each tensor
# lengths = [tensor.shape[0] for tensor in tensor_list]
# # Pad using the pad_sequence function
# tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
# padded_tensor = torch.nn.utils.rnn.pad_sequence(
# tensor_list, batch_first=True, padding_value=padding_value
# )
# return padded_tensor, lengths
#
# def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
# samples = [wav.squeeze() for wav in samples]
# device = self.tokenizer.device
# samples, lengths = self.pad_tensor_list(samples, device)
# samples = samples.unsqueeze(1)
#
# if not isinstance(samples, torch.Tensor):
# samples = torch.from_numpy(samples)
# if len(samples.shape) != 3:
# raise ValueError()
# if sampling_rate != self.tokenizer.sample_rate:
# samples = [
# convert_audio(
# wav,
# sampling_rate,
# self.tokenizer.sample_rate,
# self.tokenizer.channels,
# )
# for wav in samples
# ]
# # Extract discrete codes from EnCodec
# with torch.no_grad():
# encoded_frames = self.tokenizer.encode(samples.detach().to(device))
# encoded_frames = encoded_frames[0][0] # [B, n_q, T]
# batch_codes = []
# for b, length in enumerate(lengths):
# codes = encoded_frames[b]
# duration = round(length / sampling_rate, ndigits=12)
# expected_num_frames = compute_num_frames(
# duration=duration,
# frame_shift=self.frame_shift,
# sampling_rate=sampling_rate,
# )
# batch_codes.append(codes[..., :expected_num_frames])
# return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
if __name__ == "__main__":
    # Self-check: encoding the same random input before and after removing
    # weight normalization must yield matching codes.
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(6.0)

    random_input = np.random.random([4, 1, 1600])
    samples = torch.from_numpy(random_input).type(torch.float32)

    codes_before = codec.encode(samples)
    remove_encodec_weight_norm(codec)
    codes_after = codec.encode(samples)

    assert torch.allclose(codes_before[0][0], codes_after[0][0])