Ailyth commited on
Commit
db45ded
·
1 Parent(s): 714116a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -43
app.py CHANGED
@@ -1,19 +1,4 @@
1
  import sys, os
2
-
3
- if sys.platform == "darwin":
4
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
-
6
- import logging
7
-
8
- logging.getLogger("numba").setLevel(logging.WARNING)
9
- logging.getLogger("markdown_it").setLevel(logging.WARNING)
10
- logging.getLogger("urllib3").setLevel(logging.WARNING)
11
- logging.getLogger("matplotlib").setLevel(logging.WARNING)
12
-
13
- logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
  import torch
18
  import argparse
19
  import commons
@@ -27,7 +12,13 @@ import webbrowser
27
  import soundfile as sf
28
  from datetime import datetime
29
  import pytz
30
-
 
 
 
 
 
 
31
 
32
  net_g = None
33
  models = {
@@ -94,24 +85,17 @@ def tts_generator(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_s
94
  global net_g
95
  model_path = models[model]
96
  net_g, _, _, _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
97
- with torch.no_grad():
98
- audio = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker,model_dir=model)
99
- with open('tmp.wav', 'rb') as wav_file:
100
- mp3 = convert_wav_to_mp3(wav_file)
101
- return "生成语音成功", (hps.data.sampling_rate, audio), mp3
 
 
 
102
 
103
  if __name__ == "__main__":
104
- parser = argparse.ArgumentParser()
105
- parser.add_argument("--model_dir", default="", help="path of your model")
106
- parser.add_argument("--config_dir", default="./configs/config.json", help="path of your config file")
107
- parser.add_argument("--share", default=False, help="make link public")
108
- parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
109
-
110
- args = parser.parse_args()
111
- if args.debug:
112
- logger.info("Enable DEBUG-LEVEL log")
113
- logging.basicConfig(level=logging.DEBUG)
114
- hps = utils.get_hparams_from_file(args.config_dir)
115
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
116
 
117
  net_g = SynthesizerTrn(
@@ -122,7 +106,6 @@ if __name__ == "__main__":
122
  **hps.model).to(device)
123
  _ = net_g.eval()
124
 
125
-
126
  speaker_ids = hps.data.spk2id
127
  speakers = list(speaker_ids.keys())
128
 
@@ -130,12 +113,10 @@ if __name__ == "__main__":
130
  with gr.Row():
131
  with gr.Column():
132
 
133
-
134
- gr.Markdown(value="""
135
- 测试用
136
- """)
137
  text = gr.TextArea(label="Text", placeholder="Input Text Here",
138
- value="在不在?能不能借给我三百块钱买可乐",info="使用huggingface的免费CPU进行推理,因此速度不快,一次性不要输入超过500汉字")
 
139
 
140
  model = gr.Radio(choices=list(models.keys()), value=list(models.keys())[0], label='音声模型')
141
  #model = gr.Dropdown(choices=models,value=models[0], label='音声模型')
@@ -150,12 +131,13 @@ if __name__ == "__main__":
150
  text_output = gr.Textbox(label="Message")
151
  audio_output = gr.Audio(label="试听")
152
  MP3_output = gr.File(label="下载")
153
- gr.Markdown(value="""
154
 
155
  """)
156
- btn.click(tts_generator,
 
157
  inputs=[text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, model],
158
- outputs=[text_output, audio_output,MP3_output])
159
-
160
-
161
  app.launch(show_error=True)
 
1
  import sys, os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import argparse
4
  import commons
 
12
  import soundfile as sf
13
  from datetime import datetime
14
  import pytz
15
+ import logging
16
+ logging.getLogger("numba").setLevel(logging.WARNING)
17
+ logging.getLogger("markdown_it").setLevel(logging.WARNING)
18
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
19
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
20
+ logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")
21
+ logger = logging.getLogger(__name__)
22
 
23
  net_g = None
24
  models = {
 
85
  global net_g
86
  model_path = models[model]
87
  net_g, _, _, _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
88
+ try:
89
+ with torch.no_grad():
90
+ audio = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker,model_dir=model)
91
+ with open('tmp.wav', 'rb') as wav_file:
92
+ mp3 = convert_wav_to_mp3(wav_file)
93
+ return "生成语音成功", (hps.data.sampling_rate, audio), mp3
94
+ except Exception as e:
95
+ return "生成语音失败:" + str(e), None, None
96
 
97
  if __name__ == "__main__":
98
+ hps = utils.get_hparams_from_file("./configs/config.json")
 
 
 
 
 
 
 
 
 
 
99
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
100
 
101
  net_g = SynthesizerTrn(
 
106
  **hps.model).to(device)
107
  _ = net_g.eval()
108
 
 
109
  speaker_ids = hps.data.spk2id
110
  speakers = list(speaker_ids.keys())
111
 
 
113
  with gr.Row():
114
  with gr.Column():
115
 
116
+ gr.Markdown("测试用")
 
 
 
117
  text = gr.TextArea(label="Text", placeholder="Input Text Here",
118
+ value="在不在?能不能借给我三百块钱买可乐",
119
+ info="使用huggingface的免费CPU进行推理,因此速度不快,一次性不要输入超过500汉字")
120
 
121
  model = gr.Radio(choices=list(models.keys()), value=list(models.keys())[0], label='音声模型')
122
  #model = gr.Dropdown(choices=models,value=models[0], label='音声模型')
 
131
  text_output = gr.Textbox(label="Message")
132
  audio_output = gr.Audio(label="试听")
133
  MP3_output = gr.File(label="下载")
134
+ gr.Markdown("""
135
 
136
  """)
137
+ btn.click(
138
+ tts_generator,
139
  inputs=[text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, model],
140
+ outputs=[text_output, audio_output,MP3_output]
141
+ )
142
+
143
  app.launch(show_error=True)