keithhon committed on
Commit
6e8c3d6
·
1 Parent(s): 5ed6c8f

Upload vocoder/inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vocoder/inference.py +64 -0
vocoder/inference.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from vocoder.models.fatchord_version import WaveRNN
2
+ from vocoder import hparams as hp
3
+ import torch
4
+
5
+
6
# Module-level vocoder state, populated by load_model().
# _model holds the Wave-RNN network, or None while nothing has been loaded.
_model = None  # type: WaveRNN
# _device is the torch.device chosen by load_model(). It is initialised here so that
# reading it before a model has been loaded yields None instead of raising NameError.
_device = None  # type: torch.device
7
+
8
def load_model(weights_fpath, verbose=True):
    """
    Builds the Wave-RNN vocoder from the hyperparameters in hp and loads its weights.

    The network is placed on the GPU when CUDA is available and on the CPU otherwise, restored
    from the checkpoint and switched to eval mode. The module-level _model and _device are only
    assigned once every step has succeeded, so a failing load does not make is_loaded() report
    a network whose weights were never restored.

    :param weights_fpath: path to a checkpoint that stores the weights under the
    "model_state" key
    :param verbose: if True, prints a message before building the model and before loading
    the weights
    """
    global _model, _device

    if verbose:
        print("Building Wave-RNN")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    # NOTE: torch.load unpickles the file, so only checkpoints from a trusted source should
    # be passed in here.
    checkpoint = torch.load(weights_fpath, map_location=device)
    model.load_state_dict(checkpoint['model_state'])
    model.eval()

    # Publish the model only now that it is fully initialised.
    _model = model
    _device = device
39
+
40
+
41
def is_loaded():
    """
    Tells whether the module-level _model currently holds a model.

    :return: False while _model is None, True otherwise
    """
    if _model is None:
        return False
    return True
43
+
44
+
45
def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800,
                   progress_callback=None):
    """
    Infers the waveform of a mel spectrogram output by the synthesizer (the format must match
    that of the synthesizer!)

    :param mel: the mel spectrogram, as a numpy array or a torch tensor. A leading axis of
    size 1 is added before it is handed to the model.
    :param normalize: if True, the spectrogram is first divided by hp.mel_max_abs_value
    :param batched: passed through unchanged to the model's generate() method
    :param target: passed through unchanged to the model's generate() method
    :param overlap: passed through unchanged to the model's generate() method
    :param progress_callback: passed through unchanged to the model's generate() method
    :return: the value returned by the model's generate() method
    :raises RuntimeError: if no model has been loaded with load_model() yet
    """
    if _model is None:
        raise RuntimeError("Please load Wave-RNN in memory before using it")

    if normalize:
        # Builds a new array/tensor, so the caller's spectrogram is left untouched.
        mel = mel / hp.mel_max_abs_value
    # as_tensor converts a numpy array exactly like from_numpy does and lets an existing
    # tensor through unchanged.
    mel = torch.as_tensor(mel[None, ...])
    wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback)
    return wav