r"""VocGAN vocoder wrapper.

Loads a pretrained VocGAN generator checkpoint (downloading it when the file
is absent) and exposes a small callable that turns mel-spectrograms into
waveforms.

Author: wuxulong19950206 1287173754@qq.com
Date: 2024-03-12 22:44:31
LastEditors: wuxulong19950206 1287173754@qq.com
LastEditTime: 2024-03-12 23:05:02
FilePath: \text_to_speech\mtts\models\vocoder\VocGAN\vocgan.py
Description: This is the default setting; please set `customMade` and open
    koroFileHeader to view and adjust the configuration:
    https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
"""
import argparse
import glob
import os

import numpy as np
import torch
import tqdm
from scipy.io.wavfile import write

from .denoiser import Denoiser
from .model.generator import ModifiedGenerator
from .utils.hparams import HParam, load_hparam_str
from .download_utils import download_url

MAX_WAV_VALUE = 32768.0

url = 'https://zenodo.org/record/4743731/files/vctk_pretrained_model_3180.pt'


class VocGan:
    """Mel-spectrogram to waveform synthesizer backed by a VocGAN generator.

    The instance is callable: ``vocoder(mel)`` is the same as
    ``vocoder.synthesize(mel)``.
    """

    # File name of the pretrained checkpoint inside the checkpoint directory.
    CHECKPOINT_FILENAME = 'vctk_pretrained_model_3180.pt'
    # Strength passed to the denoiser when denoising is enabled.
    DENOISE_STRENGTH = 0.01
    # Number of hop-length frames trimmed from the end of the generated audio.
    TRIM_FRAMES = 10

    def __init__(self, device='cuda:0', config=None, denoise=False):
        """Load the pretrained generator, downloading the checkpoint if needed.

        Args:
            device: Device used when ``config`` has no ``"device"`` entry.
            config: Mapping-like object indexed with ``"checkpoint"``
                (directory holding the checkpoint file; required),
                ``"denoise"`` and ``"device"``. Values found in ``config``
                take precedence over the keyword arguments.
            denoise: Whether to denoise the output when ``config`` has no
                ``"denoise"`` entry.

        Raises:
            ValueError: If no checkpoint directory is configured.
        """
        checkpoint_path = self._config_value(config, "checkpoint", None)
        if checkpoint_path is None:
            raise ValueError(
                "VocGan requires config['checkpoint'] to name the directory "
                "that holds (or will receive) the pretrained checkpoint"
            )
        denoise = self._config_value(config, "denoise", denoise)
        device = self._config_value(config, "device", device)

        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint_file = os.path.join(checkpoint_path, self.CHECKPOINT_FILENAME)
        if not os.path.exists(checkpoint_file):
            download_url(url, checkpoint_path)

        checkpoint = torch.load(checkpoint_file, map_location=device)
        # Hyper-parameters always come from the string embedded in the
        # checkpoint; ``config`` only carries runtime options.
        hp = load_hparam_str(checkpoint['hp_str'])
        self.hp = hp

        self.model = ModifiedGenerator(
            hp.audio.n_mel_channels,
            hp.model.n_residual_layers,
            ratios=hp.model.generator_ratio,
            mult=hp.model.mult,
            out_band=hp.model.out_channels,
        ).to(device)
        self.model.load_state_dict(checkpoint['model_g'])
        self.model.eval(inference=True)
        self.model = self.model.to(device)

        self.denoise = denoise
        self.device = device
        # Built lazily on the first denoised synthesis and then reused.
        self._denoiser = None

    @staticmethod
    def _config_value(config, key, default):
        """Return ``config[key]``, or ``default`` if config is None or lacks key."""
        if config is None:
            return default
        try:
            return config[key]
        except KeyError:
            return default

    def _get_denoiser(self):
        """Return the cached denoiser, creating it on first use."""
        # NOTE(review): caching assumes Denoiser keeps no per-call state;
        # its constructor inputs (model, device) do not change between calls.
        denoiser = getattr(self, "_denoiser", None)
        if denoiser is None:
            denoiser = Denoiser(self.model, device=self.device)
            self._denoiser = denoiser
        return denoiser

    def synthesize(self, mel):
        """Generate a waveform from a mel-spectrogram.

        Args:
            mel: A ``torch.Tensor`` or anything ``torch.tensor`` accepts.
                A 2-D input gets a leading batch dimension added.

        Returns:
            A NumPy array holding the generated audio, with the last
            ``hop_length * TRIM_FRAMES`` samples removed. The values are
            returned as produced by the model/denoiser; they are not scaled
            by ``MAX_WAV_VALUE`` or converted to int16.
        """
        with torch.no_grad():
            if not isinstance(mel, torch.Tensor):
                mel = torch.tensor(mel)
            if len(mel.shape) == 2:
                mel = mel.unsqueeze(0)
            mel = mel.to(self.device)

            audio = self.model.inference(mel)
            audio = audio.squeeze(0)  # drop the leading batch dimension
            if self.denoise:
                audio = self._get_denoiser()(audio, self.DENOISE_STRENGTH)
            audio = audio.squeeze()

            # Presumably removes the tail produced by padding inside
            # ``model.inference`` -- TODO confirm against the generator.
            audio = audio[:-(self.hp.audio.hop_length * self.TRIM_FRAMES)]
            audio = audio.cpu().detach().numpy()
        return audio

    def __call__(self, mel):
        """Shorthand for :meth:`synthesize`."""
        return self.synthesize(mel)