wuxulong19950206 commited on
Commit
9270314
·
1 Parent(s): 7c59e13
checkpoints/a.txt ADDED
File without changes
examples/biaobei/config.yaml CHANGED
@@ -59,7 +59,7 @@ vocoder:
59
  config: ~/checkpoints/melgan/default.yaml
60
  device: cpu
61
  VocGan:
62
- checkpoint: ~/checkpoints/vctk_pretrained_model_3180.pt #~/checkpoints/ljspeech_29de09d_4000.pt
63
  denoise: True
64
  device: cpu
65
  HiFiGAN:
 
59
  config: ~/checkpoints/melgan/default.yaml
60
  device: cpu
61
  VocGan:
62
+ checkpoint: checkpoints #~/checkpoints/ljspeech_29de09d_4000.pt
63
  denoise: True
64
  device: cpu
65
  HiFiGAN:
input.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ text1|sil ni3 qu4 zuo4 fan4 ba5 sil|sil 你 去 做 饭 吧 sil|0
mtts/models/vocoder/VocGAN/vocgan.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  import argparse
2
  import glob
3
  import os
@@ -16,13 +24,16 @@ from .download_utils import download_url
16
  url = 'https://zenodo.org/record/4743731/files/vctk_pretrained_model_3180.pt'
17
  class VocGan:
18
  def __init__(self, device='cuda:0',config=None, denoise=False):
19
- home = os.environ['HOME']
20
- checkpoint_path = os.path.join(home,'./.cache/vocgan')
 
 
 
21
  os.makedirs(checkpoint_path,exist_ok=True)
22
  checkpoint_file = os.path.join(checkpoint_path,'vctk_pretrained_model_3180.pt')
23
  if not os.path.exists(checkpoint_file):
24
  download_url(url,checkpoint_path)
25
-
26
  checkpoint = torch.load(checkpoint_file,map_location=device)
27
  if config is not None:
28
  hp = HParam(config)
 
1
+ '''
2
+ Author: wuxulong19950206 [email protected]
3
+ Date: 2024-03-12 22:44:31
4
+ LastEditors: wuxulong19950206 [email protected]
5
+ LastEditTime: 2024-03-12 23:05:02
6
+ FilePath: \text_to_speech\mtts\models\vocoder\VocGAN\vocgan.py
7
+ Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
8
+ '''
9
  import argparse
10
  import glob
11
  import os
 
24
  url = 'https://zenodo.org/record/4743731/files/vctk_pretrained_model_3180.pt'
25
  class VocGan:
26
  def __init__(self, device='cuda:0',config=None, denoise=False):
27
+ # home = os.environ['HOME']
28
+ checkpoint_path = config["checkpoint"]
29
+ denoise = config["denoise"]
30
+ device = config["device"]
31
+ # checkpoint_path = os.path.join(home,'./.cache/vocgan')
32
  os.makedirs(checkpoint_path,exist_ok=True)
33
  checkpoint_file = os.path.join(checkpoint_path,'vctk_pretrained_model_3180.pt')
34
  if not os.path.exists(checkpoint_file):
35
  download_url(url,checkpoint_path)
36
+ config = None
37
  checkpoint = torch.load(checkpoint_file,map_location=device)
38
  if config is not None:
39
  hp = HParam(config)
synthesize.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import subprocess
4
+
5
+ import numpy as np
6
+ import torch
7
+ import yaml
8
+ from scipy.io import wavfile
9
+
10
+ from mtts.models.fs2_model import FastSpeech2
11
+ from mtts.models.vocoder import *
12
+ from mtts.text import TextProcessor
13
+ from mtts.utils.logging import get_logger
14
+
15
+ logger = get_logger(__file__)
16
+
17
+
18
+ def check_ffmpeg():
19
+ r, path = subprocess.getstatusoutput("which ffmpeg")
20
+ return r == 0
21
+
22
+
23
+ with_ffmpeg = check_ffmpeg()
24
+
25
+
26
+ def build_vocoder(device, config):
27
+ vocoder_name = config['vocoder']['type']
28
+ VocoderClass = eval(vocoder_name)
29
+ model = VocoderClass(config=config['vocoder'][vocoder_name])
30
+ return model
31
+
32
+
33
+ def normalize(wav):
34
+ assert wav.dtype == np.float32
35
+ eps = 1e-6
36
+ sil = wav[1500:2000]
37
+ #wav = wav - np.mean(sil)
38
+ #wav = (wav - np.min(wav))/(np.max(wav)-np.min(wav)+eps)
39
+ wav = wav / np.max(np.abs(wav))
40
+ #wav = wav*2-1
41
+ wav = wav * 32767
42
+ return wav.astype('int16')
43
+
44
+
45
+ def to_int16(wav):
46
+ wav = wav = wav * 32767
47
+ wav = np.clamp(wav, -32767, 32768)
48
+ return wav.astype('int16')
49
+
50
+
51
+ if __name__ == '__main__':
52
+ parser = argparse.ArgumentParser()
53
+ parser.add_argument('-i', '--input', type=str, default='input.txt')
54
+ parser.add_argument('--duration', type=float, default=1.0)
55
+ parser.add_argument('--output_dir', type=str, default='./outputs/')
56
+ parser.add_argument('--checkpoint', type=str, required=False, default='checkpoints\checkpoint_140000.pth.tar')
57
+ parser.add_argument('-c', '--config', type=str, default='./config.yaml')
58
+ parser.add_argument('-d', '--device', choices=['cuda', 'cpu'], type=str, default='cuda')
59
+ args = parser.parse_args()
60
+
61
+ if not os.path.exists(args.output_dir):
62
+ os.makedirs(args.output_dir)
63
+
64
+ with open(args.config) as f:
65
+ config = yaml.safe_load(f)
66
+ logger.info(f.read())
67
+
68
+ sr = config['fbank']['sample_rate']
69
+
70
+ vocoder = build_vocoder(args.device, config)
71
+ text_processor = TextProcessor(config)
72
+ model = FastSpeech2(config)
73
+
74
+ if args.checkpoint != '':
75
+ sd = torch.load(args.checkpoint, map_location=args.device)
76
+ if 'model' in sd.keys():
77
+ sd = sd['model']
78
+ model.load_state_dict(sd)
79
+ del sd # to save mem
80
+ model = model.to(args.device)
81
+ torch.set_grad_enabled(False)
82
+
83
+ try:
84
+ lines = open(args.input).read().split('\n')
85
+ except:
86
+ print('Failed to open text file', args.input)
87
+ print('Treating input as text')
88
+ lines = [args.input]
89
+
90
+ for line in lines:
91
+ if len(line) == 0 or line.startswith('#'):
92
+ continue
93
+ logger.info(f'processing {line}')
94
+ name, tokens = text_processor(line)
95
+ tokens = tokens.to(args.device)
96
+ seq_len = torch.tensor([tokens.shape[1]])
97
+ tokens = tokens.unsqueeze(1)
98
+ seq_len = seq_len.to(args.device)
99
+ max_src_len = torch.max(seq_len)
100
+ output = model(tokens, seq_len, max_src_len=max_src_len, d_control=args.duration)
101
+ mel_pred, mel_postnet, d_pred, src_mask, mel_mask, mel_len = output
102
+
103
+ # convert to waveform using vocoder
104
+ mel_postnet = mel_postnet[0].transpose(0, 1).detach()
105
+ mel_postnet += config['fbank']['mel_mean']
106
+ wav = vocoder(mel_postnet)
107
+ if config['synthesis']['normalize']:
108
+ wav = normalize(wav)
109
+ else:
110
+ wav = to_int16(wav)
111
+ dst_file = os.path.join(args.output_dir, f'{name}.wav')
112
+ #np.save(dst_file+'.npy',mel_postnet.cpu().numpy())
113
+ logger.info(f'writing file to {dst_file}')
114
+ wavfile.write(dst_file, sr, wav)