File size: 1,144 Bytes
14d1720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import glob
import tqdm
import torch
import argparse
from scipy.io.wavfile import write
from torch import Tensor
from .model.generator import Generator
from .utils.hparams import HParam, load_hparam_str

# Not referenced anywhere in the code visible here.
# NOTE(review): presumably the 16-bit PCM full-scale magnitude (2**15) used to
# scale float audio before writing a WAV file -- confirm against callers.
MAX_WAV_VALUE = 32768.0

class MelGAN:
    """Inference wrapper around the MelGAN ``Generator``.

    Builds a ``Generator`` from hyper-parameters, loads its weights from a
    checkpoint file and exposes ``synthesize`` (also reachable by calling the
    instance) to run the generator on a mel tensor.
    """

    def __init__(self, checkpoint, device='cpu', config=None):
        """Load the generator weights and prepare the model.

        Args:
            checkpoint: Path to a file readable by ``torch.load``; a leading
                ``~`` is expanded.
            device: Device the model is moved to and on which inputs are
                placed before inference.
            config: Optional path to a hyper-parameter file passed to
                ``HParam`` (``~`` is expanded). When ``None``, the
                hyper-parameters are parsed from the checkpoint's
                ``'hp_str'`` entry instead.
        """
        checkpoint_path = os.path.expanduser(checkpoint)

        # Remap stored tensors onto the requested device so that a checkpoint
        # saved on a GPU can still be loaded on a CPU-only machine.
        ckpt = torch.load(checkpoint_path, map_location=torch.device(device))
        if config is not None:
            # expanduser() raises TypeError on None, so only expand a real
            # path; the None case falls through to the checkpoint's hp_str.
            hp = HParam(os.path.expanduser(config))
        else:
            hp = load_hparam_str(ckpt['hp_str'])

        self.model = Generator(hp.audio.n_mel_channels).to(device)
        # NOTE(review): weight norm is stripped *before* the weights are
        # loaded, so the checkpoint presumably stores a state dict exported
        # without weight-norm parameters -- confirm against the export script.
        self.model.remove_weight_norm()
        self.device = device
        self.model.load_state_dict(ckpt)
        self.model.eval(inference=False)

    @torch.no_grad()
    def synthesize(self, mel: Tensor):
        """Run the generator on ``mel`` and return the result as a NumPy array.

        Args:
            mel: Input tensor. A 2-D tensor gets a leading dimension of size
                one added (presumably the batch axis -- confirm); tensors of
                any other rank are passed through unchanged.

        Returns:
            The output of ``self.model.inference`` moved to the CPU and
            converted with ``.numpy()``.
        """
        if len(mel.shape) == 2:
            mel = mel.unsqueeze(0)
        mel = mel.to(self.device)

        audio = self.model.inference(mel)
        audio = audio.cpu().detach().numpy()
        return audio

    def __call__(self, mel):
        """Shorthand for ``synthesize(mel)``."""
        return self.synthesize(mel)