RUSH-miaomi committed
Commit b4a5b14 · 1 Parent(s): 05fa709

Update app.py

Files changed (1):
  1. app.py +107 -27
app.py CHANGED
@@ -28,6 +28,12 @@ import webbrowser
 
 net_g = None
 
+if sys.platform == "darwin" and torch.backends.mps.is_available():
+    device = "mps"
+    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+else:
+    device = "cuda"
+
 
 def get_text(text, language_str, hps):
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
@@ -40,53 +46,128 @@ def get_text(text, language_str, hps):
     for i in range(len(word2ph)):
         word2ph[i] = word2ph[i] * 2
     word2ph[0] += 1
-    bert = get_bert(norm_text, word2ph, language_str)
+    bert = get_bert(norm_text, word2ph, language_str, device)
     del word2ph
-
-    assert bert.shape[-1] == len(phone)
+    assert bert.shape[-1] == len(phone), phone
+
+    if language_str == "ZH":
+        bert = bert
+        ja_bert = torch.zeros(768, len(phone))
+    elif language_str == "JP":
+        ja_bert = bert
+        bert = torch.zeros(1024, len(phone))
+    else:
+        bert = torch.zeros(1024, len(phone))
+        ja_bert = torch.zeros(768, len(phone))
+
+    assert bert.shape[-1] == len(
+        phone
+    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
 
     phone = torch.LongTensor(phone)
     tone = torch.LongTensor(tone)
     language = torch.LongTensor(language)
+    return bert, ja_bert, phone, tone, language
 
-    return bert, phone, tone, language
-
-def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
+
+def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
     global net_g
-    bert, phones, tones, lang_ids = get_text(text, "ZH", hps)
+    bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps)
     with torch.no_grad():
-        x_tst=phones.to(device).unsqueeze(0)
-        tones=tones.to(device).unsqueeze(0)
-        lang_ids=lang_ids.to(device).unsqueeze(0)
+        x_tst = phones.to(device).unsqueeze(0)
+        tones = tones.to(device).unsqueeze(0)
+        lang_ids = lang_ids.to(device).unsqueeze(0)
         bert = bert.to(device).unsqueeze(0)
+        ja_bert = ja_bert.to(device).unsqueeze(0)
         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
         del phones
         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
-        audio = net_g.infer(x_tst, x_tst_lengths, speakers, tones, lang_ids, bert, sdp_ratio=sdp_ratio
-            , noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0,0].data.cpu().float().numpy()
+        audio = (
+            net_g.infer(
+                x_tst,
+                x_tst_lengths,
+                speakers,
+                tones,
+                lang_ids,
+                bert,
+                ja_bert,
+                sdp_ratio=sdp_ratio,
+                noise_scale=noise_scale,
+                noise_scale_w=noise_scale_w,
+                length_scale=length_scale,
+            )[0][0, 0]
+            .data.cpu()
+            .float()
+            .numpy()
+        )
         del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
+        torch.cuda.empty_cache()
         return audio
 
-def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
+
+def generate_audio(slices, sdp_ratio, noise_scale, noise_scale_w, length_scale, speaker, language):
+    audio_list = []
+    silence = np.zeros(hps.data.sampling_rate // 2)
     with torch.no_grad():
-        audio = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker)
-        return "Success", (hps.data.sampling_rate, audio)
+        for piece in slices:
+            audio = infer(
+                piece,
+                sdp_ratio=sdp_ratio,
+                noise_scale=noise_scale,
+                noise_scale_w=noise_scale_w,
+                length_scale=length_scale,
+                sid=speaker,
+                language=language,
+            )
+            audio_list.append(audio)
+            audio_list.append(silence)  # append the silence to the list
+    return audio_list
+
+
+def tts_fn(text: str, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language):
+    audio_list = []
+    if language == "mix":
+        bool_valid, str_valid = re_matching.validate_text(text)
+        if not bool_valid:
+            return str_valid, (hps.data.sampling_rate, np.concatenate([np.zeros(hps.data.sampling_rate // 2)]))
+        result = re_matching.text_matching(text)
+        for one in result:
+            _speaker = one.pop()
+            for lang, content in one:
+                audio_list.extend(
+                    generate_audio(content.split("|"), sdp_ratio, noise_scale,
+                                   noise_scale_w, length_scale, _speaker + "_" + lang.lower(), lang)
+                )
+    else:
+        audio_list.extend(
+            generate_audio(text.split("|"), sdp_ratio, noise_scale, noise_scale_w, length_scale, speaker, language)
+        )
+
+    audio_concat = np.concatenate(audio_list)
+    return "Success", (hps.data.sampling_rate, audio_concat)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_dir", default="./logs/maolei/G_4800.pth", help="path of your model")
-    parser.add_argument("--config_dir", default="./configs/config.json", help="path of your config file")
-    parser.add_argument("--share", default=False, help="make link public")
-    parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
+    parser.add_argument(
+        "-m", "--model", default="./logs/maolei/G_4800.pth", help="path of your model"
+    )
+    parser.add_argument(
+        "-c",
+        "--config",
+        default="./configs/config.json",
+        help="path of your config file",
+    )
+    parser.add_argument(
+        "--share", default=False, help="make link public", action="store_true"
+    )
+    parser.add_argument(
+        "-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log"
    )
 
     args = parser.parse_args()
     if args.debug:
         logger.info("Enable DEBUG-LEVEL log")
         logging.basicConfig(level=logging.DEBUG)
-    hps = utils.get_hparams_from_file(args.config_dir)
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    '''
+    hps = utils.get_hparams_from_file(args.config)
+
     device = (
         "cuda:0"
         if torch.cuda.is_available()
@@ -96,16 +177,16 @@ if __name__ == "__main__":
         else "cpu"
     )
     )
-    '''
     net_g = SynthesizerTrn(
         len(symbols),
         hps.data.filter_length // 2 + 1,
         hps.train.segment_size // hps.data.hop_length,
         n_speakers=hps.data.n_speakers,
-        **hps.model).to(device)
+        **hps.model,
+    ).to(device)
     _ = net_g.eval()
 
-    _ = utils.load_checkpoint(args.model_dir, net_g, None, skip_optimizer=True)
+    _ = utils.load_checkpoint(args.model, net_g, None, skip_optimizer=True)
 
     speaker_ids = hps.data.spk2id
     speakers = list(speaker_ids.keys())
@@ -141,7 +222,6 @@ if __name__ == "__main__":
         outputs=[text_output, audio_output],
     )
 
-    # webbrowser.open("http://127.0.0.1:6006")
-    # app.launch(server_port=6006, show_error=True)
+
 
     app.launch(show_error=True)
 
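Notes on the changes:

The new device pick covers Apple-silicon (MPS) and CUDA hosts, but the else branch hard-codes "cuda" even on machines without a GPU. A minimal self-contained sketch of the same pattern, with a CPU fallback added (the CPU branch is my assumption, not part of this commit):

    import os
    import sys

    import torch

    # Same order as the commit: prefer MPS on macOS, else CUDA.
    if sys.platform == "darwin" and torch.backends.mps.is_available():
        device = "mps"
        # Let ops unsupported on MPS fall back to the CPU instead of raising.
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"  # assumption: fallback for CPU-only hosts

    print(f"using device: {device}")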
 
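get_text now returns two BERT feature tensors with fixed channel counts: 1024 for the Chinese features and 768 for the Japanese ones, with zeros filled in for whichever language is not in use (the sizes presumably match a large Chinese BERT and a base-size Japanese BERT; the pretrained models themselves are not shown in this diff). A small sketch of that shape contract:

    import torch

    def split_bert_features(bert: torch.Tensor, language_str: str, seq_len: int):
        """Mirror the language branch in get_text, shapes only."""
        if language_str == "ZH":
            ja_bert = torch.zeros(768, seq_len)   # Japanese side zeroed
        elif language_str == "JP":
            ja_bert = bert                        # 768 x seq_len from get_bert
            bert = torch.zeros(1024, seq_len)     # Chinese side zeroed
        else:
            bert = torch.zeros(1024, seq_len)
            ja_bert = torch.zeros(768, seq_len)
        return bert, ja_bert

    bert, ja_bert = split_bert_features(torch.randn(1024, 17), "ZH", 17)
    print(bert.shape, ja_bert.shape)  # torch.Size([1024, 17]) torch.Size([768, 17])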
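infer now takes a language argument instead of hard-coding "ZH", moves the extra ja_bert tensor to the device alongside the others, and after synthesis deletes the intermediates and calls torch.cuda.empty_cache() to return cached GPU memory to the allocator; the call is harmless on MPS/CPU hosts, since it is a no-op while CUDA is uninitialized. A hedged usage sketch, assuming app.py has been loaded and net_g, hps, and speakers are initialized as in __main__ (parameter values are illustrative only):

    # sid must be a key of hps.data.spk2id.
    audio = infer(
        "你好，世界",
        sdp_ratio=0.2,
        noise_scale=0.6,
        noise_scale_w=0.8,
        length_scale=1.0,
        sid=speakers[0],
        language="ZH",
    )
    # audio is a float numpy array at hps.data.sampling_rate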
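The CLI flags are renamed (--model_dir becomes -m/--model, --config_dir becomes -c/--config), and --share now behaves as a boolean switch: before this change it had no action="store_true", so argparse expected a value for it and stored whatever string was passed. A self-contained sketch of how the new flags parse, e.g. for python app.py -m ./logs/maolei/G_4800.pth --share -d:

    import argparse

    # Mirrors the argparse setup in the diff (help strings omitted).
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", default="./logs/maolei/G_4800.pth")
    parser.add_argument("-c", "--config", default="./configs/config.json")
    parser.add_argument("--share", default=False, action="store_true")
    parser.add_argument("-d", "--debug", action="store_true")

    args = parser.parse_args(["-m", "./logs/maolei/G_4800.pth", "--share", "-d"])
    print(args.model, args.config, args.share, args.debug)
    # ./logs/maolei/G_4800.pth ./configs/config.json True True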
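One loose end visible in these lines: args.share is parsed but never forwarded, since the final call is still app.launch(show_error=True). If a public link is intended, the flag would need to be wired through; Gradio's launch() does accept a share keyword, so a one-line sketch (assuming app is the Gradio app object built earlier in the file):

    app.launch(share=args.share, show_error=True)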