from functools import lru_cache

import json
import os

import numpy as np
import torch
import yaml
from scipy.io import wavfile

from mtts.models.fs2_model import FastSpeech2
from mtts.text import TextProcessor

with open("dict_han_pinyin.json", "r", encoding="utf-8") as f:
    data_dict = json.load(f)
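# dict_han_pinyin.json is assumed to map single Han characters to pinyin
# syllables (with tone numbers), e.g. {"你": "ni3", "好": "hao3"}; it is
# looked up character by character in get_pretrained_model() below.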
def normalize(wav):
    assert wav.dtype == np.float32
    eps = 1e-6
    sil = wav[1500:2000]  # leading-silence segment (only used by the commented-out variants)
    #wav = wav - np.mean(sil)
    #wav = (wav - np.min(wav))/(np.max(wav)-np.min(wav)+eps)
    # peak-normalize to [-1, 1]; eps guards against division by zero on silent audio
    wav = wav / (np.max(np.abs(wav)) + eps)
    #wav = wav*2-1
    wav = wav * 32767
    return wav.astype('int16')
def to_int16(wav):
    wav = wav * 32767
    # clip to the valid int16 range before casting
    wav = np.clip(wav, -32768, 32767)
    return wav.astype('int16')
def __build_vocoder(config):
    # the vocoder class is looked up by name from the config; the class named in
    # config['vocoder']['type'] must already be imported into this module's namespace
    vocoder_name = config['vocoder']['type']
    VocoderClass = eval(vocoder_name)
    model = VocoderClass(config=config['vocoder'][vocoder_name])
    return model
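# For reference, a minimal sketch of the vocoder section that __build_vocoder
# expects in config.yaml (illustrative only; the actual biaobei config may use
# a different class name and options):
#
#   vocoder:
#     type: VocGan                  # class name resolved by eval() above
#     VocGan:
#       checkpoint: <path to vocoder checkpoint>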
def get_pretrained_model(line):
    # forward slashes avoid backslash-escape problems in the path literals
    config = "text_to_speech/examples/biaobei/config.yaml"
    checkpoint = "text_to_speech/checkpoints/checkpoint_140000.pth.tar"
    with open(config) as f:
        config = yaml.safe_load(f)
    sr = config['fbank']['sample_rate']
    vocoder = __build_vocoder(config)
    text_processor = TextProcessor(config)
    model = FastSpeech2(config)
    if checkpoint != '':
        sd = torch.load(checkpoint, map_location="cpu")
        if 'model' in sd.keys():
            sd = sd['model']
        model.load_state_dict(sd)
        del sd  # to save memory
    model = model.to("cpu")
    torch.set_grad_enabled(False)
    # build the mtts input line: per-character pinyin and hanzi sequences,
    # each wrapped in silence ("sil") tokens
    pinyin = ""
    hanzi = ""
    for i in line:
        pinyin += data_dict[i] + " "
        hanzi += i + " "
    post_line = f"text1|sil {pinyin}sil|sil {hanzi}sil|0"
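    # For the input "你好" (assuming the dictionary maps 你 -> "ni3" and 好 -> "hao3"),
    # post_line becomes: "text1|sil ni3 hao3 sil|sil 你 好 sil|0"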
    name, tokens = text_processor(post_line)
    tokens = tokens.to("cpu")
    seq_len = torch.tensor([tokens.shape[1]])
    tokens = tokens.unsqueeze(1)
    seq_len = seq_len.to("cpu")
    max_src_len = torch.max(seq_len)
    output = model(tokens, seq_len, max_src_len=max_src_len, d_control=1.0)
    mel_pred, mel_postnet, d_pred, src_mask, mel_mask, mel_len = output
    # convert the predicted mel spectrogram to a waveform using the vocoder
    mel_postnet = mel_postnet[0].transpose(0, 1).detach()
    mel_postnet += config['fbank']['mel_mean']
    wav = vocoder(mel_postnet)
    if config['synthesis']['normalize']:
        wav = normalize(wav)
    else:
        wav = to_int16(wav)
    dst_file = os.path.join(f'{name}.wav')
    #np.save(dst_file+'.npy', mel_postnet.cpu().numpy())
    wavfile.write(dst_file, sr, wav)
    return dst_file, 2.0
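# Example usage (a minimal sketch; assumes dict_han_pinyin.json, the biaobei
# config, and the FastSpeech2 checkpoint referenced above are present locally):
#
#   wav_path, duration = get_pretrained_model("你好")
#   print(wav_path)  # e.g. "text1.wav", written at config['fbank']['sample_rate']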
chinese_models = {
    "csukuangfj/vits-piper-zh_CN-huayan-medium": 1,
}

language_to_models = {
    "Chinese (Mandarin, 普通话)": list(chinese_models.keys()),
}