fix
Browse files- data/tokenizer.py +0 -260
data/tokenizer.py
CHANGED
@@ -22,160 +22,6 @@ import torch
|
|
22 |
import torchaudio
|
23 |
from encodec import EncodecModel
|
24 |
from encodec.utils import convert_audio
|
25 |
-
from phonemizer.backend import EspeakBackend
|
26 |
-
from phonemizer.backend.espeak.language_switch import LanguageSwitch
|
27 |
-
from phonemizer.backend.espeak.words_mismatch import WordMismatch
|
28 |
-
from phonemizer.punctuation import Punctuation
|
29 |
-
from phonemizer.separator import Separator
|
30 |
-
from phonemizer.separator import Separator
|
31 |
-
|
32 |
-
try:
|
33 |
-
from pypinyin import Style, pinyin
|
34 |
-
from pypinyin.style._utils import get_finals, get_initials
|
35 |
-
except Exception:
|
36 |
-
pass
|
37 |
-
|
38 |
-
|
39 |
-
class PypinyinBackend:
|
40 |
-
"""PypinyinBackend for Chinese. Most codes is referenced from espnet.
|
41 |
-
There are two types pinyin or initials_finals, one is
|
42 |
-
just like "ni1 hao3", the other is like "n i1 h ao3".
|
43 |
-
"""
|
44 |
-
|
45 |
-
def __init__(
|
46 |
-
self,
|
47 |
-
backend="initials_finals",
|
48 |
-
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
|
49 |
-
) -> None:
|
50 |
-
self.backend = backend
|
51 |
-
self.punctuation_marks = punctuation_marks
|
52 |
-
|
53 |
-
def phonemize(
|
54 |
-
self, text: List[str], separator: Separator, strip=True, njobs=1
|
55 |
-
) -> List[str]:
|
56 |
-
assert isinstance(text, List)
|
57 |
-
phonemized = []
|
58 |
-
for _text in text:
|
59 |
-
_text = re.sub(" +", " ", _text.strip())
|
60 |
-
_text = _text.replace(" ", separator.word)
|
61 |
-
phones = []
|
62 |
-
if self.backend == "pypinyin":
|
63 |
-
for n, py in enumerate(
|
64 |
-
pinyin(
|
65 |
-
_text, style=Style.TONE3, neutral_tone_with_five=True
|
66 |
-
)
|
67 |
-
):
|
68 |
-
if all([c in self.punctuation_marks for c in py[0]]):
|
69 |
-
if len(phones):
|
70 |
-
assert phones[-1] == separator.syllable
|
71 |
-
phones.pop(-1)
|
72 |
-
|
73 |
-
phones.extend(list(py[0]))
|
74 |
-
else:
|
75 |
-
phones.extend([py[0], separator.syllable])
|
76 |
-
elif self.backend == "pypinyin_initials_finals":
|
77 |
-
for n, py in enumerate(
|
78 |
-
pinyin(
|
79 |
-
_text, style=Style.TONE3, neutral_tone_with_five=True
|
80 |
-
)
|
81 |
-
):
|
82 |
-
if all([c in self.punctuation_marks for c in py[0]]):
|
83 |
-
if len(phones):
|
84 |
-
assert phones[-1] == separator.syllable
|
85 |
-
phones.pop(-1)
|
86 |
-
phones.extend(list(py[0]))
|
87 |
-
else:
|
88 |
-
if py[0][-1].isalnum():
|
89 |
-
initial = get_initials(py[0], strict=False)
|
90 |
-
if py[0][-1].isdigit():
|
91 |
-
final = (
|
92 |
-
get_finals(py[0][:-1], strict=False)
|
93 |
-
+ py[0][-1]
|
94 |
-
)
|
95 |
-
else:
|
96 |
-
final = get_finals(py[0], strict=False)
|
97 |
-
phones.extend(
|
98 |
-
[
|
99 |
-
initial,
|
100 |
-
separator.phone,
|
101 |
-
final,
|
102 |
-
separator.syllable,
|
103 |
-
]
|
104 |
-
)
|
105 |
-
else:
|
106 |
-
assert ValueError
|
107 |
-
else:
|
108 |
-
raise NotImplementedError
|
109 |
-
phonemized.append(
|
110 |
-
"".join(phones).rstrip(f"{separator.word}{separator.syllable}")
|
111 |
-
)
|
112 |
-
return phonemized
|
113 |
-
|
114 |
-
|
115 |
-
class TextTokenizer:
|
116 |
-
"""Phonemize Text."""
|
117 |
-
|
118 |
-
def __init__(
|
119 |
-
self,
|
120 |
-
language="en-us",
|
121 |
-
backend="espeak",
|
122 |
-
separator=Separator(word="_", syllable="-", phone="|"),
|
123 |
-
preserve_punctuation=True,
|
124 |
-
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
|
125 |
-
with_stress: bool = False,
|
126 |
-
tie: Union[bool, str] = False,
|
127 |
-
language_switch: LanguageSwitch = "keep-flags",
|
128 |
-
words_mismatch: WordMismatch = "ignore",
|
129 |
-
) -> None:
|
130 |
-
if backend == "espeak":
|
131 |
-
phonemizer = EspeakBackend(
|
132 |
-
language,
|
133 |
-
punctuation_marks=punctuation_marks,
|
134 |
-
preserve_punctuation=preserve_punctuation,
|
135 |
-
with_stress=with_stress,
|
136 |
-
tie=tie,
|
137 |
-
language_switch=language_switch,
|
138 |
-
words_mismatch=words_mismatch,
|
139 |
-
)
|
140 |
-
elif backend in ["pypinyin", "pypinyin_initials_finals"]:
|
141 |
-
phonemizer = PypinyinBackend(
|
142 |
-
backend=backend,
|
143 |
-
punctuation_marks=punctuation_marks + separator.word,
|
144 |
-
)
|
145 |
-
else:
|
146 |
-
raise NotImplementedError(f"{backend}")
|
147 |
-
|
148 |
-
self.backend = phonemizer
|
149 |
-
self.separator = separator
|
150 |
-
|
151 |
-
def to_list(self, phonemized: str) -> List[str]:
|
152 |
-
fields = []
|
153 |
-
for word in phonemized.split(self.separator.word):
|
154 |
-
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
|
155 |
-
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
|
156 |
-
fields.extend(
|
157 |
-
[p for p in pp if p != self.separator.phone]
|
158 |
-
+ [self.separator.word]
|
159 |
-
)
|
160 |
-
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
|
161 |
-
self.separator.phone
|
162 |
-
)
|
163 |
-
return fields[:-1]
|
164 |
-
|
165 |
-
def __call__(self, text, strip=True) -> List[List[str]]:
|
166 |
-
if isinstance(text, str):
|
167 |
-
text = [text]
|
168 |
-
|
169 |
-
phonemized = self.backend.phonemize(
|
170 |
-
text, separator=self.separator, strip=strip, njobs=1
|
171 |
-
)
|
172 |
-
return [self.to_list(p) for p in phonemized]
|
173 |
-
|
174 |
-
|
175 |
-
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
|
176 |
-
phonemes = tokenizer([text.strip()])
|
177 |
-
return phonemes[0] # k2symbols
|
178 |
-
|
179 |
|
180 |
def remove_encodec_weight_norm(model):
|
181 |
from encodec.modules import SConv1d
|
@@ -256,112 +102,6 @@ def tokenize_audio(tokenizer: AudioTokenizer, audio):
|
|
256 |
return encoded_frames
|
257 |
|
258 |
|
259 |
-
# @dataclass
|
260 |
-
# class AudioTokenConfig:
|
261 |
-
# frame_shift: Seconds = 320.0 / 24000
|
262 |
-
# num_quantizers: int = 8
|
263 |
-
#
|
264 |
-
# def to_dict(self) -> Dict[str, Any]:
|
265 |
-
# return asdict(self)
|
266 |
-
#
|
267 |
-
# @staticmethod
|
268 |
-
# def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
|
269 |
-
# return AudioTokenConfig(**data)
|
270 |
-
#
|
271 |
-
#
|
272 |
-
# class AudioTokenExtractor(FeatureExtractor):
|
273 |
-
# name = "encodec"
|
274 |
-
# config_type = AudioTokenConfig
|
275 |
-
#
|
276 |
-
# def __init__(self, config: Optional[Any] = None):
|
277 |
-
# super(AudioTokenExtractor, self).__init__(config)
|
278 |
-
# self.tokenizer = AudioTokenizer()
|
279 |
-
#
|
280 |
-
# def extract(
|
281 |
-
# self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
|
282 |
-
# ) -> np.ndarray:
|
283 |
-
# if not isinstance(samples, torch.Tensor):
|
284 |
-
# samples = torch.from_numpy(samples)
|
285 |
-
# if sampling_rate != self.tokenizer.sample_rate:
|
286 |
-
# samples = convert_audio(
|
287 |
-
# samples,
|
288 |
-
# sampling_rate,
|
289 |
-
# self.tokenizer.sample_rate,
|
290 |
-
# self.tokenizer.channels,
|
291 |
-
# )
|
292 |
-
# if len(samples.shape) == 2:
|
293 |
-
# samples = samples.unsqueeze(0)
|
294 |
-
# else:
|
295 |
-
# raise ValueError()
|
296 |
-
#
|
297 |
-
# device = self.tokenizer.device
|
298 |
-
# encoded_frames = self.tokenizer.encode(samples.detach().to(device))
|
299 |
-
# codes = encoded_frames[0][0] # [B, n_q, T]
|
300 |
-
# if True:
|
301 |
-
# duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
|
302 |
-
# expected_num_frames = compute_num_frames(
|
303 |
-
# duration=duration,
|
304 |
-
# frame_shift=self.frame_shift,
|
305 |
-
# sampling_rate=sampling_rate,
|
306 |
-
# )
|
307 |
-
# assert abs(codes.shape[-1] - expected_num_frames) <= 1
|
308 |
-
# codes = codes[..., :expected_num_frames]
|
309 |
-
# return codes.cpu().squeeze(0).permute(1, 0).numpy()
|
310 |
-
#
|
311 |
-
# @property
|
312 |
-
# def frame_shift(self) -> Seconds:
|
313 |
-
# return self.config.frame_shift
|
314 |
-
#
|
315 |
-
# def feature_dim(self, sampling_rate: int) -> int:
|
316 |
-
# return self.config.num_quantizers
|
317 |
-
#
|
318 |
-
# def pad_tensor_list(self, tensor_list, device, padding_value=0):
|
319 |
-
# # 计算每个张量的长度
|
320 |
-
# lengths = [tensor.shape[0] for tensor in tensor_list]
|
321 |
-
# # 使用pad_sequence函数进行填充
|
322 |
-
# tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
|
323 |
-
# padded_tensor = torch.nn.utils.rnn.pad_sequence(
|
324 |
-
# tensor_list, batch_first=True, padding_value=padding_value
|
325 |
-
# )
|
326 |
-
# return padded_tensor, lengths
|
327 |
-
#
|
328 |
-
# def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
|
329 |
-
# samples = [wav.squeeze() for wav in samples]
|
330 |
-
# device = self.tokenizer.device
|
331 |
-
# samples, lengths = self.pad_tensor_list(samples, device)
|
332 |
-
# samples = samples.unsqueeze(1)
|
333 |
-
#
|
334 |
-
# if not isinstance(samples, torch.Tensor):
|
335 |
-
# samples = torch.from_numpy(samples)
|
336 |
-
# if len(samples.shape) != 3:
|
337 |
-
# raise ValueError()
|
338 |
-
# if sampling_rate != self.tokenizer.sample_rate:
|
339 |
-
# samples = [
|
340 |
-
# convert_audio(
|
341 |
-
# wav,
|
342 |
-
# sampling_rate,
|
343 |
-
# self.tokenizer.sample_rate,
|
344 |
-
# self.tokenizer.channels,
|
345 |
-
# )
|
346 |
-
# for wav in samples
|
347 |
-
# ]
|
348 |
-
# # Extract discrete codes from EnCodec
|
349 |
-
# with torch.no_grad():
|
350 |
-
# encoded_frames = self.tokenizer.encode(samples.detach().to(device))
|
351 |
-
# encoded_frames = encoded_frames[0][0] # [B, n_q, T]
|
352 |
-
# batch_codes = []
|
353 |
-
# for b, length in enumerate(lengths):
|
354 |
-
# codes = encoded_frames[b]
|
355 |
-
# duration = round(length / sampling_rate, ndigits=12)
|
356 |
-
# expected_num_frames = compute_num_frames(
|
357 |
-
# duration=duration,
|
358 |
-
# frame_shift=self.frame_shift,
|
359 |
-
# sampling_rate=sampling_rate,
|
360 |
-
# )
|
361 |
-
# batch_codes.append(codes[..., :expected_num_frames])
|
362 |
-
# return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
|
363 |
-
|
364 |
-
|
365 |
if __name__ == "__main__":
|
366 |
model = EncodecModel.encodec_model_24khz()
|
367 |
model.set_target_bandwidth(6.0)
|
|
|
22 |
import torchaudio
|
23 |
from encodec import EncodecModel
|
24 |
from encodec.utils import convert_audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def remove_encodec_weight_norm(model):
|
27 |
from encodec.modules import SConv1d
|
|
|
102 |
return encoded_frames
|
103 |
|
104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
if __name__ == "__main__":
|
106 |
model = EncodecModel.encodec_model_24khz()
|
107 |
model.set_target_bandwidth(6.0)
|