# History note (was stray blame residue above the imports, which broke the file):
# author wuxulong19950206, "First model version", commit 14d1720
from __future__ import absolute_import, division, print_function, unicode_literals
import glob
import os
import numpy as np
import argparse
import json
import torch
from scipy.io.wavfile import write
from .env import AttrDict
from .meldataset import MAX_WAV_VALUE
from .models import Generator
def load_checkpoint(filepath, device):
    """Load a serialized checkpoint dict from *filepath* onto *device*.

    Args:
        filepath: path to a ``torch.save``-produced checkpoint file.
        device: torch device (or device string) passed as ``map_location``.

    Returns:
        The deserialized checkpoint dictionary.

    Raises:
        FileNotFoundError: if *filepath* is not an existing regular file.
    """
    # Fix: was `assert os.path.isfile(...)`, which is stripped under `python -O`
    # and raises AssertionError instead of a meaningful error otherwise.
    if not os.path.isfile(filepath):
        raise FileNotFoundError("Checkpoint file not found: '{}'".format(filepath))
    print("Loading '{}'".format(filepath))
    # NOTE: torch.load unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
def scan_checkpoint(cp_dir, prefix):
    """Return the lexicographically greatest file in *cp_dir* whose name
    starts with *prefix*, or '' when no such file exists.

    Checkpoint names carry zero-padded step numbers, so the string maximum
    is the latest checkpoint.
    """
    matches = glob.glob(os.path.join(cp_dir, prefix + '*'))
    if not matches:
        return ''
    return max(matches)
class HiFiGAN:
    """Inference wrapper around a trained HiFi-GAN generator checkpoint.

    Loads the generator weights, strips weight norm, and exposes a simple
    mel-spectrogram -> waveform call interface.
    """

    def __init__(self, checkpoint: os.PathLike, h=None, device='cuda'):
        """Build the generator from *checkpoint*.

        Args:
            checkpoint: path to a generator checkpoint file; a ``config.json``
                is expected alongside it when *h* is not supplied.
            h: optional pre-loaded hyperparameter config (attribute-style
                access, e.g. an AttrDict). Fix: the original unconditionally
                reloaded config.json, silently ignoring a caller-supplied *h*;
                the file is now read only when *h* is None.
            device: torch device string the generator (and inputs) are moved to.
        """
        if h is None:
            config_file = os.path.join(os.path.split(checkpoint)[0], 'config.json')
            with open(config_file) as f:
                h = AttrDict(json.loads(f.read()))
        self.generator = Generator(h).to(device)
        state_dict_g = load_checkpoint(checkpoint, device)
        self.generator.load_state_dict(state_dict_g['generator'])
        self.generator.eval()
        # Weight norm is a training-time reparameterization; fold it away for inference.
        self.generator.remove_weight_norm()
        self.device = device

    def inference(self, x):
        """Synthesize a waveform from mel spectrogram *x*.

        Args:
            x: numpy array or torch tensor; a batch dimension is added via
               ``unsqueeze(0)`` before the generator call.
               # assumes shape (n_mels, frames) — TODO confirm against Generator

        Returns:
            1-D float numpy array of raw audio samples (not scaled to int16 —
            NOTE(review): MAX_WAV_VALUE is imported but unused; verify callers
            expect float output).
        """
        with torch.no_grad():
            if isinstance(x, np.ndarray):
                x = torch.FloatTensor(x).to(self.device)
            else:
                x = x.to(self.device)
            y_g_hat = self.generator(x.unsqueeze(0))
            return y_g_hat.squeeze().cpu().numpy()

    def __call__(self, x):
        """Alias for :meth:`inference`."""
        return self.inference(x)