File size: 3,852 Bytes
524ad56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import json

import torch

import utils
from onnxexport.model_onnx_speaker_mix import SynthesizerTrn


def main(path: str = "crs") -> None:
    """Export a SoVits checkpoint to ONNX and write its JSON description.

    Reads ``checkpoints/{path}/config.json`` and
    ``checkpoints/{path}/model.pth``, runs the model once on CPU with random
    dummy inputs, exports it to ``checkpoints/{path}/{path}_SoVits.onnx`` and
    finally writes the ``MoeVSConf`` dictionary to ``checkpoints/{path}.json``.

    Args:
        path: Name of the model folder below ``checkpoints/``; also used as
            the stem of the exported file names. Defaults to ``"crs"``.
    """
    device = torch.device("cpu")
    hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
    SVCVITS = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model)
    utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)
    SVCVITS.eval().to(device)
    # Export is inference-only, so no parameter needs gradients.
    SVCVITS.requires_grad_(False)

    num_frames = 200

    # Random dummy inputs used for the dry run and for the ONNX export.
    test_hidden_unit = torch.rand(1, num_frames, SVCVITS.gin_channels)
    test_pitch = torch.rand(1, num_frames)
    test_vol = torch.rand(1, num_frames)
    test_mel2ph = torch.arange(num_frames, dtype=torch.long).unsqueeze(0)
    test_uv = torch.ones(1, num_frames, dtype=torch.float32)
    # NOTE(review): 192 is hard-coded here; presumably it has to match the
    # model's latent channel count -- confirm against the model config.
    test_noise = torch.randn(1, 192, num_frames)

    # Speaker mixing is only exported when at least two speakers exist.
    export_mix = len(hps.spk) >= 2
    if export_mix:
        n_spk = len(hps.spk)
        SVCVITS.export_chara_mix(hps.spk)
        # One row of uniform speaker weights per frame: (num_frames, n_spk).
        uniform_weights = [1.0 / float(n_spk)] * n_spk
        test_sid = torch.tensor(uniform_weights).unsqueeze(0).repeat(num_frames, 1)
    else:
        test_sid = torch.LongTensor([0])

    SVCVITS.eval()

    # Axes that stay dynamic in the exported graph, keyed by input name.
    daxes = {
        "c": [0, 1],
        "f0": [1],
        "mel2ph": [1],
        "uv": [1],
        "noise": [2],
    }
    if export_mix:
        # With speaker mixing, "sid" has one row per frame.
        daxes["sid"] = [0]

    input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
    output_names = ["audio", ]
    dummy_tensors = [
        test_hidden_unit,
        test_pitch,
        test_mel2ph,
        test_uv,
        test_noise,
        test_sid,
    ]

    # The volume input only exists when the model has a volume embedding.
    if SVCVITS.vol_embedding:
        input_names.append("vol")
        daxes["vol"] = [1]
        dummy_tensors.append(test_vol)

    test_inputs = tuple(tensor.to(device) for tensor in dummy_tensors)

    # SVCVITS = torch.jit.script(SVCVITS)
    # Dry run with exactly the inputs that are handed to the exporter, so a
    # signature or shape problem surfaces before the export starts.
    SVCVITS(*test_inputs)

    torch.onnx.export(
        SVCVITS,
        test_inputs,
        f"checkpoints/{path}/{path}_SoVits.onnx",
        dynamic_axes=daxes,
        do_constant_folding=False,
        opset_version=16,
        verbose=False,
        input_names=input_names,
        output_names=output_names
    )

    vec_lay = "layer-12" if SVCVITS.gin_channels == 768 else "layer-9"
    spklist = list(hps.spk.keys())

    MoeVSConf = {
        "Folder" : f"{path}",
        "Name" : f"{path}",
        "Type" : "SoVits",
        "Rate" : hps.data.sampling_rate,
        "Hop" : hps.data.hop_length,
        "Hubert": f"vec-{SVCVITS.gin_channels}-{vec_lay}",
        "SoVits4": True,
        "SoVits3": False,
        "CharaMix": export_mix,
        "Volume": SVCVITS.vol_embedding,
        "HiddenSize": SVCVITS.gin_channels,
        "Characters": spklist
    }

    # The description is written next to the model folder, not inside it.
    with open(f"checkpoints/{path}.json", 'w', encoding="utf-8") as MoeVsConfFile:
        json.dump(MoeVSConf, MoeVsConfFile, indent = 4)


# Run the export only when this file is executed as a script.
if __name__ == "__main__":
    main()