|
import os |
|
|
|
# Set before the imports below (numpy in particular) are executed.
# NOTE(review): presumably this pins native OpenMP-backed libraries to a single
# thread per process so parallel binarization workers do not oversubscribe the
# CPU -- confirm against how BaseBinarizer.process() spawns workers.
os.environ["OMP_NUM_THREADS"] = "1"
|
|
|
from data_gen.tts.txt_processors.zh_g2pM import ALL_SHENMU |
|
from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError |
|
from data_gen.tts.data_gen_utils import get_mel2ph |
|
from utils.hparams import set_hparams, hparams |
|
import numpy as np |
|
|
|
|
|
class ZhBinarizer(BaseBinarizer):
    """Binarizer whose alignment step post-processes phoneme durations.

    The only member defined here is :meth:`get_align`; everything else comes
    from ``BaseBinarizer``.
    """

    @staticmethod
    def get_align(tg_fn, ph, mel, phone_encoded, res):
        """Fill ``res['dur']`` and ``res['mel2ph']`` for one item.

        Raw durations are obtained from ``get_mel2ph`` and then adjusted in
        two passes:

        1. Separator tokens (first character is neither ``'<'`` nor a letter)
           hand their leading voiced frames (``f0 != 0``) to the preceding
           phoneme. If fewer than 100 frames are left afterwards, the
           remainder is handed over as well and the separator gets duration 0.
        2. Each token found in ``ALL_SHENMU`` (presumably the Mandarin
           initials -- confirm in ``zh_g2pM``) that is followed by an
           alphabetic token outside ``ALL_SHENMU`` shares the pair's total
           duration equally with that follower. Pairs that do not satisfy
           this are reported on stdout and left untouched.

        Args:
            tg_fn: Path of the alignment file; must exist on disk.
            ph: Phoneme sequence as a single space-separated string.
            mel: Passed through unchanged to ``get_mel2ph``.
            phone_encoded: Encoded phoneme sequence; only its length is used,
                to validate the produced indices.
            res: Mapping that provides ``res['f0']`` and receives
                ``res['mel2ph']`` (1-based phoneme index per frame) and
                ``res['dur']``.

        Raises:
            BinarizationError: If the alignment file is missing, or the
                largest phoneme index does not fit ``phone_encoded``.
        """
        if tg_fn is None or not os.path.exists(tg_fn):
            raise BinarizationError("Align not found")
        _, dur = get_mel2ph(tg_fn, ph, mel, hparams)

        ph_list = ph.split(" ")
        assert len(dur) == len(ph_list)
        n_ph = len(ph_list)

        # Pass 1: give separator frames to the preceding phoneme.
        # Boundaries are taken from the unadjusted durations; moving frames
        # between neighbours i-1 and i never shifts the boundaries after i.
        dur_cumsum = np.pad(np.cumsum(dur), [1, 0], mode='constant', constant_values=0)
        for i, p in enumerate(ph_list):
            if p[0] == '<' or p[0].isalpha():
                continue
            if i == 0:
                # No preceding phoneme exists; dur[i - 1] would wrap around to
                # the LAST phoneme and corrupt the alignment.
                continue
            uv_ = res['f0'][dur_cumsum[i]:dur_cumsum[i + 1]] == 0
            # Count the voiced frames at the start of the separator span.
            j = 0
            while j < len(uv_) and not uv_[j]:
                j += 1
            dur[i - 1] += j
            dur[i] -= j
            if dur[i] < 100:
                # Short leftover: absorb the separator completely.
                dur[i - 1] += dur[i]
                dur[i] = 0

        # Pass 2: equalise the duration of each ALL_SHENMU token and its follower.
        for i, p in enumerate(ph_list):
            if p not in ALL_SHENMU:
                continue
            # A trailing token has no follower; treat it like any other
            # unpairable token instead of raising IndexError.
            p_next = ph_list[i + 1] if i + 1 < n_ph else None
            pairable = (
                p_next is not None
                and dur[i] > 0
                and p_next[0].isalpha()
                and p_next not in ALL_SHENMU
            )
            if not pairable:
                print(f"assert dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU, "
                      f"dur[i]: {dur[i]}, p: {p}, p_next: {p_next}.")
                continue
            total = dur[i + 1] + dur[i]
            dur[i] = total // 2
            dur[i + 1] = total - dur[i]

        # Expand durations into a per-frame list of 1-based phoneme indices.
        mel2ph = []
        for i, n_frames in enumerate(dur):
            mel2ph += [i + 1] * n_frames
        mel2ph = np.array(mel2ph)

        max_ph_idx = mel2ph.max() - 1
        if max_ph_idx >= len(phone_encoded):
            raise BinarizationError(f"| Align does not match: {(max_ph_idx, len(phone_encoded))}")
        res['mel2ph'] = mel2ph
        res['dur'] = dur
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. set_hparams() runs first; NOTE(review): presumably it
    # populates the shared `hparams` object that get_align reads -- confirm in
    # utils.hparams.
    set_hparams()
    ZhBinarizer().process()
|
|