File size: 1,850 Bytes
f42d33e
fcf8202
f42d33e
77efc8b
f42d33e
 
f5b630a
f42d33e
f5b630a
 
 
 
 
 
 
 
f42d33e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5b630a
 
f42d33e
77efc8b
f42d33e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch,json,os

from scipy.io import wavfile

import numpy as np
# Read the character -> pinyin lookup table once, at import time.
# The file is resolved relative to the current working directory.
with open("dict_han_pinyin.json", encoding="utf-8") as f:
    print("loading")
    data_dict = json.loads(f.read())


# Model identifiers offered for Mandarin.
# NOTE(review): the meaning of the value ``1`` is not visible in this file.
chinese_models = {"csukuangfj/vits-piper-zh_CN-huayan-medium": 1}

# Display name of each language mapped to the list of its model identifiers.
language_to_models = {"Chinese (Mandarin, 普通话)": list(chinese_models)}


def normalize(wav):
    """Peak-normalize a float32 waveform and convert it to int16 samples.

    The signal is divided by its largest absolute sample, scaled by 32767
    and cast to ``int16`` (the cast truncates toward zero).

    Args:
        wav: NumPy array of dtype ``float32``.

    Returns:
        ``int16`` array with the same shape as ``wav``. Empty or all-zero
        input yields zeros rather than dividing by zero.

    Raises:
        AssertionError: if ``wav`` is not of dtype ``float32``.
    """
    assert wav.dtype == np.float32, f"expected float32 audio, got {wav.dtype}"

    # np.max raises on an empty array, so handle it before taking the peak.
    if wav.size == 0:
        return wav.astype('int16')

    peak = np.max(np.abs(wav))
    # Pure silence: dividing by a zero peak would produce NaNs.
    if peak == 0:
        return np.zeros_like(wav, dtype='int16')

    return (wav / peak * 32767).astype('int16')


def to_int16(wav):
    """Scale a floating-point waveform by 32767 and convert it to int16.

    Unlike ``normalize``, no peak normalization is applied; scaled samples
    that fall outside the int16 range are clipped instead of wrapping.

    Args:
        wav: NumPy array of floating-point samples.

    Returns:
        ``int16`` array with the same shape as ``wav``.
    """
    scaled = wav * 32767
    # Clip to the exact int16 range so the cast below cannot overflow.
    scaled = np.clip(scaled, -32768, 32767)
    return scaled.astype('int16')



def get_pretrained_model(model,line,config,text_processor,vocoder):
    """Synthesize ``line`` to a WAV file and return the file's path.

    Each character of ``line`` is looked up in the module-level
    ``data_dict`` to build a pinyin sequence, the text is formatted and
    tokenized by ``text_processor``, ``model`` predicts a mel spectrogram,
    and ``vocoder`` turns it into a waveform that is written to disk.

    Args:
        model: Callable invoked as ``model(tokens, seq_len,
            max_src_len=..., d_control=1.0)`` and returning a 6-tuple whose
            second element is the post-net mel prediction.
        line: Iterable of characters; every character must be a key of
            ``data_dict``.
        config: Nested mapping providing ``['fbank']['sample_rate']``,
            ``['fbank']['mel_mean']`` and ``['synthesis']['normalize']``.
        text_processor: Callable taking the formatted text line and
            returning ``(name, tokens)``.
        vocoder: Callable converting the mel spectrogram to a waveform.
            NOTE(review): its output is assumed to be a NumPy array, since
            ``normalize``/``to_int16`` call ``astype`` on it -- confirm.

    Returns:
        Tuple ``(dst_file, 2.0)`` where ``dst_file`` is ``"<name>.wav"``,
        written relative to the current working directory.
        NOTE(review): the meaning of the constant ``2.0`` is not visible here.

    Raises:
        KeyError: if a character of ``line`` has no entry in ``data_dict``.
    """
    sr = config['fbank']['sample_rate']

    # Materialize once so a one-shot iterable can be traversed twice below.
    chars = list(line)
    try:
        syllables = [data_dict[ch] for ch in chars]
    except KeyError as err:
        raise KeyError(
            f"no pinyin entry for character {err.args[0]!r}"
        ) from err

    # Every item keeps a trailing space, matching the format expected below.
    pinyin = "".join(syllable + " " for syllable in syllables)
    hanzi = "".join(ch + " " for ch in chars)
    post_line = f"text1|sil {pinyin}sil|sil {hanzi}sil|0"

    name, tokens = text_processor(post_line)
    tokens = tokens.to("cpu")
    # The length is taken before the extra dimension is inserted.
    seq_len = torch.tensor([tokens.shape[1]])
    tokens = tokens.unsqueeze(1)
    seq_len = seq_len.to("cpu")
    max_src_len = torch.max(seq_len)

    # Inference only: skip autograd bookkeeping for the model and vocoder.
    with torch.no_grad():
        output = model(tokens, seq_len, max_src_len=max_src_len, d_control=1.0)
        # Order: mel_pred, mel_postnet, d_pred, src_mask, mel_mask, mel_len.
        _, mel_postnet, _, _, _, _ = output

        # convert to waveform using vocoder
        mel_postnet = mel_postnet[0].transpose(0, 1).detach()
        mel_postnet += config['fbank']['mel_mean']
        wav = vocoder(mel_postnet)

    if config['synthesis']['normalize']:
        wav = normalize(wav)
    else:
        wav = to_int16(wav)

    dst_file = f'{name}.wav'
    wavfile.write(dst_file, sr, wav)
    return dst_file,2.0