File size: 2,710 Bytes
dce1ab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import time
import numpy as np
from tqdm import tqdm
import torch

from models.svc.base import SVCInference
from models.svc.vits.vits import SynthesizerTrn

from models.svc.base.svc_dataset import SVCTestDataset, SVCTestCollator
from utils.io import save_audio
from utils.audio_slicer import is_silence


class VitsInference(SVCInference):
    """Inference wrapper around the VITS-based SVC synthesizer.

    Builds a ``SynthesizerTrn``, runs it over the test dataloader, and
    writes one wav per test utterance into ``args.output_dir``.
    """

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        # `infer_type` is accepted for interface compatibility with the
        # other SVC inference classes but is not used by this subclass.
        SVCInference.__init__(self, args, cfg)

    def _build_model(self):
        """Construct the VITS generator and register it as ``self.model``.

        The first argument is the linear-spectrogram bin count
        (``n_fft // 2 + 1``); the second is the training segment size
        expressed in frames (samples divided by hop size).
        """
        net_g = SynthesizerTrn(
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            self.cfg,
        )
        self.model = net_g
        return net_g

    def build_save_dir(self, dataset, speaker):
        """Build (and create) the output directory for synthesized audio.

        Layout: ``<output_dir>/svc_am_step-<step>_<mode>[/data_<dataset>][/spk_<speaker>]``.
        ``dataset`` may be None and ``speaker`` may be -1 to skip the
        corresponding sub-directory level.
        """
        save_dir = os.path.join(
            self.args.output_dir,
            "svc_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
        )
        if dataset is not None:
            save_dir = os.path.join(save_dir, "data_{}".format(dataset))
        if speaker != -1:
            save_dir = os.path.join(
                save_dir,
                "spk_{}".format(speaker),
            )
        os.makedirs(save_dir, exist_ok=True)
        print("Saving to ", save_dir)
        return save_dir

    @torch.inference_mode()
    def inference(self):
        """Synthesize every test utterance and save it to disk.

        Returns:
            list[str]: paths of the wav files written, in dataset order.
        """
        res = []
        # Running offset into the dataset metadata. BUGFIX: the previous
        # implementation zipped the FULL metadata list against each batch's
        # predictions, so every batch reused the first uids and earlier
        # output files were silently overwritten when there was more than
        # one batch.
        uid_offset = 0
        for batch in self.test_dataloader:
            pred_audio_list = self._inference_each_batch(batch)
            for wav in pred_audio_list:
                uid = self.test_dataset.metadata[uid_offset]["Uid"]
                uid_offset += 1
                file = os.path.join(self.args.output_dir, f"{uid}.wav")

                # force=True also handles tensors that live on GPU.
                wav = wav.numpy(force=True)
                save_audio(
                    file,
                    wav,
                    self.cfg.preprocess.sample_rate,
                    add_silence=False,
                    # Don't normalize loudness for near-silent outputs.
                    turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
                )
                res.append(file)
        return res

    def _inference_each_batch(self, batch_data, noise_scale=0.667):
        """Run the generator on one batch.

        Args:
            batch_data: dict of tensors; every value is moved to the
                accelerator device in place.
            noise_scale: sampling temperature passed to ``model.infer``
                (0.667 is the conventional VITS default).

        Returns:
            list of per-utterance audio tensors.
        """
        device = self.accelerator.device
        pred_res = []
        self.model.eval()
        # no_grad kept so this method is safe even when called outside
        # the @torch.inference_mode() context of inference().
        with torch.no_grad():
            for k, v in batch_data.items():
                batch_data[k] = v.to(device)

            audios, f0 = self.model.infer(batch_data, noise_scale=noise_scale)

            pred_res.extend(audios)

        return pred_res