Spaces:
Sleeping
Sleeping
from __future__ import absolute_import, division, print_function, unicode_literals | |
import glob | |
import os | |
import numpy as np | |
import argparse | |
import json | |
import torch | |
from scipy.io.wavfile import write | |
from .env import AttrDict | |
from .meldataset import MAX_WAV_VALUE | |
from .models import Generator | |
def load_checkpoint(filepath, device): | |
assert os.path.isfile(filepath) | |
print("Loading '{}'".format(filepath)) | |
checkpoint_dict = torch.load(filepath, map_location=device) | |
print("Complete.") | |
return checkpoint_dict | |
def scan_checkpoint(cp_dir, prefix): | |
pattern = os.path.join(cp_dir, prefix + '*') | |
cp_list = glob.glob(pattern) | |
if len(cp_list) == 0: | |
return '' | |
return sorted(cp_list)[-1] | |
class HiFiGAN: | |
def __init__(self,checkpoint:os.PathLike, h=None, device='cuda'): | |
config_file = os.path.join(os.path.split(checkpoint)[0], 'config.json') | |
with open(config_file) as f: | |
data = f.read() | |
json_config = json.loads(data) | |
h = AttrDict(json_config) | |
self.generator = Generator(h).to(device) | |
state_dict_g = load_checkpoint(checkpoint, device) | |
self.generator.load_state_dict(state_dict_g['generator']) | |
self.generator.eval() | |
self.generator.remove_weight_norm() | |
self.device = device | |
def inference(self,x): | |
with torch.no_grad(): | |
if isinstance(x,np.ndarray): | |
x = torch.FloatTensor(x).to(self.device) | |
else: | |
x = x.to(self.device) | |
y_g_hat = self.generator(x.unsqueeze(0)) | |
audio = y_g_hat.squeeze() | |
audio = audio.cpu().numpy() | |
return audio | |
def __call__(self,x): | |
return self.inference(x) | |