Mahiruoshi committed
Commit 994e4b7 · 1 Parent(s): e829f7e

Update server.py

Files changed (1)
  1. server.py +177 -118
server.py CHANGED
@@ -1,25 +1,72 @@
-from flask import Flask, request, Response
-from io import BytesIO
+import argparse
+import os
+from pathlib import Path
+
+import logging
+import re_matching
+
+from flask import Flask, request, jsonify
+from flask_cors import CORS
+
+logging.getLogger("numba").setLevel(logging.WARNING)
+logging.getLogger("markdown_it").setLevel(logging.WARNING)
+logging.getLogger("urllib3").setLevel(logging.WARNING)
+logging.getLogger("matplotlib").setLevel(logging.WARNING)
+
+logging.basicConfig(
+    level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
+)
+
+logger = logging.getLogger(__name__)
+import librosa
+import numpy as np
 import torch
-from av import open as avopen
+import torch.nn as nn
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+from transformers import Wav2Vec2Processor
+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2Model,
+    Wav2Vec2PreTrainedModel,
+)
+
+import utils
+from config import config
 
+import torch
 import commons
+from text import cleaned_text_to_sequence, get_bert
+from emo_gen import process_func, EmotionModel, Wav2Vec2Processor, Wav2Vec2Model, Wav2Vec2PreTrainedModel, RegressionHead
+from text.cleaner import clean_text
 import utils
+
 from models import SynthesizerTrn
 from text.symbols import symbols
-from text import cleaned_text_to_sequence, get_bert
-from text.cleaner import clean_text
-from scipy.io import wavfile
+import sys
 
-# Flask Init
-app = Flask(__name__)
-app.config["JSON_AS_ASCII"] = False
+from scipy.io.wavfile import write
+
+net_g = None
 
+device = 'cpu'
 
-def get_text(text, language_str, hps):
+def get_net_g(model_path: str, version: str, device: str, hps):
+    net_g = SynthesizerTrn(
+        len(symbols),
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    ).to(device)
+    _ = net_g.eval()
+    _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
+    return net_g
+
+def get_text(text, language_str, hps, device):
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
-
+    #print(text)
     if hps.data.add_blank:
         phone = commons.intersperse(phone, 0)
         tone = commons.intersperse(tone, 0)
@@ -27,38 +74,83 @@ def get_text(text, language_str, hps):
         for i in range(len(word2ph)):
             word2ph[i] = word2ph[i] * 2
         word2ph[0] += 1
-    bert = get_bert(norm_text, word2ph, language_str)
+    bert_ori = get_bert(norm_text, word2ph, language_str, device)
     del word2ph
-    assert bert.shape[-1] == len(phone), phone
+    assert bert_ori.shape[-1] == len(phone), phone
 
     if language_str == "ZH":
-        bert = bert
-        ja_bert = torch.zeros(768, len(phone))
-    elif language_str == "JA":
-        ja_bert = bert
+        bert = bert_ori
+        ja_bert = torch.zeros(1024, len(phone))
+        en_bert = torch.zeros(1024, len(phone))
+    elif language_str == "JP":
         bert = torch.zeros(1024, len(phone))
-    else:
+        ja_bert = bert_ori
+        en_bert = torch.zeros(1024, len(phone))
+    elif language_str == "EN":
         bert = torch.zeros(1024, len(phone))
-        ja_bert = torch.zeros(768, len(phone))
+        ja_bert = torch.zeros(1024, len(phone))
+        en_bert = bert_ori
+    else:
+        raise ValueError("language_str should be ZH, JP or EN")
+
     assert bert.shape[-1] == len(
         phone
     ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
+
     phone = torch.LongTensor(phone)
     tone = torch.LongTensor(tone)
     language = torch.LongTensor(language)
-    return bert, ja_bert, phone, tone, language
+    return bert, ja_bert, en_bert, phone, tone, language
 
+def get_emo_(reference_audio, emotion):
 
-def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
-    bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps)
+    if (emotion == 10 and reference_audio):
+        emo = torch.from_numpy(get_emo(reference_audio))
+    else:
+        emo = torch.Tensor([emotion])
+
+    return emo
+
+def get_emo(path):
+    wav, sr = librosa.load(path, 16000)
+    device = config.bert_gen_config.device
+    return process_func(
+        np.expand_dims(wav, 0).astype(np.float64),
+        sr,
+        emotional_model,
+        emotional_processor,
+        device,
+        embeddings=True,
+    ).squeeze(0)
+
+def infer(
+    text,
+    sdp_ratio,
+    noise_scale,
+    noise_scale_w,
+    length_scale,
+    sid,
+    reference_audio=None,
+    emotion=0,
+):
+
+    language = 'JP' if is_japanese(text) else 'ZH'
+    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
+        text, language, hps, device
+    )
+    emo = get_emo_(reference_audio, emotion)
     with torch.no_grad():
-        x_tst = phones.to(dev).unsqueeze(0)
-        tones = tones.to(dev).unsqueeze(0)
-        lang_ids = lang_ids.to(dev).unsqueeze(0)
-        bert = bert.to(dev).unsqueeze(0)
+        x_tst = phones.to(device).unsqueeze(0)
+        tones = tones.to(device).unsqueeze(0)
+        lang_ids = lang_ids.to(device).unsqueeze(0)
+        bert = bert.to(device).unsqueeze(0)
         ja_bert = ja_bert.to(device).unsqueeze(0)
-        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(dev)
-        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(dev)
+        en_bert = en_bert.to(device).unsqueeze(0)
+        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+        emo = emo.to(device).unsqueeze(0)
+        print(emo)
+        del phones
+        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
@@ -68,6 +160,8 @@ def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
                 lang_ids,
                 bert,
                 ja_bert,
+                en_bert,
+                emo,
                 sdp_ratio=sdp_ratio,
                 noise_scale=noise_scale,
                 noise_scale_w=noise_scale_w,
@@ -77,94 +171,59 @@ def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
             .float()
             .numpy()
         )
-        return audio
-
-
-def replace_punctuation(text, i=2):
-    punctuation = ",。?!"
-    for char in punctuation:
-        text = text.replace(char, char * i)
-    return text
-
-
-def wav2(i, o, format):
-    inp = avopen(i, "rb")
-    out = avopen(o, "wb", format=format)
-    if format == "ogg":
-        format = "libvorbis"
-
-    ostream = out.add_stream(format)
-
-    for frame in inp.decode(audio=0):
-        for p in ostream.encode(frame):
-            out.mux(p)
-
-    for p in ostream.encode(None):
-        out.mux(p)
-
-    out.close()
-    inp.close()
-
-
-# Load Generator
-hps = utils.get_hparams_from_file("./configs/config.json")
-
-dev = "cuda"
-net_g = SynthesizerTrn(
-    len(symbols),
-    hps.data.filter_length // 2 + 1,
-    hps.train.segment_size // hps.data.hop_length,
-    n_speakers=hps.data.n_speakers,
-    **hps.model,
-).to(dev)
-_ = net_g.eval()
-
-_ = utils.load_checkpoint("logs/G_649000.pth", net_g, None, skip_optimizer=True)
-
-
-@app.route("/")
-def main():
-    try:
-        speaker = request.args.get("speaker")
-        text = request.args.get("text").replace("/n", "")
-        sdp_ratio = float(request.args.get("sdp_ratio", 0.2))
-        noise = float(request.args.get("noise", 0.5))
-        noisew = float(request.args.get("noisew", 0.6))
-        length = float(request.args.get("length", 1.2))
-        language = request.args.get("language")
-        if length >= 2:
-            return "Too big length"
-        if len(text) >= 250:
-            return "Too long text"
-        fmt = request.args.get("format", "wav")
-        if None in (speaker, text):
-            return "Missing Parameter"
-        if fmt not in ("mp3", "wav", "ogg"):
-            return "Invalid Format"
-        if language not in ("JA", "ZH"):
-            return "Invalid language"
-    except:
-        return "Invalid Parameter"
-
-    with torch.no_grad():
-        audio = infer(
-            text,
-            sdp_ratio=sdp_ratio,
-            noise_scale=noise,
-            noise_scale_w=noisew,
-            length_scale=length,
-            sid=speaker,
-            language=language,
-        )
+        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        write("temp.wav", 44100, audio)
+        return 'success'
+
+def is_japanese(string):
+    for ch in string:
+        if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
+            return True
+    return False
+
+def loadmodel(model):
+    _ = net_g.eval()
+    _ = utils.load_checkpoint(model, net_g, None, skip_optimizer=True)
+    return "success"
 
-    with BytesIO() as wav:
-        wavfile.write(wav, hps.data.sampling_rate, audio)
-        torch.cuda.empty_cache()
-        if fmt == "wav":
-            return Response(wav.getvalue(), mimetype="audio/wav")
-        wav.seek(0, 0)
-        with BytesIO() as ofp:
-            wav2(wav, ofp, fmt)
-            return Response(
-                ofp.getvalue(), mimetype="audio/mpeg" if fmt == "mp3" else "audio/ogg"
-            )
+app = Flask(__name__)
+CORS(app)
+@app.route('/tts')
+
+def tts():
+    # These don't need to be changed
+    speaker = request.args.get('speaker')
+    sdp_ratio = float(request.args.get('sdp_ratio', 0.2))
+    noise_scale = float(request.args.get('noise_scale', 0.6))
+    noise_scale_w = float(request.args.get('noise_scale_w', 0.8))
+    length_scale = float(request.args.get('length_scale', 1))
+    text = request.args.get('text')
+    status = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker, reference_audio=None, emotion=0)
+    with open('temp.wav', 'rb') as bit:
+        wav_bytes = bit.read()
+
+    headers = {
+        'Content-Type': 'audio/wav',
+        'Text': status.encode('utf-8')}
+    return wav_bytes, 200, headers
+
+
+if __name__ == "__main__":
+    emotional_model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
+    REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
+    emotional_processor = Wav2Vec2Processor.from_pretrained(emotional_model_name)
+    emotional_model = EmotionModel.from_pretrained(emotional_model_name).to(device)
+    languages = ["Auto", "ZH", "JP"]
+    modelPaths = []
+    for dirpath, dirnames, filenames in os.walk("Data/Bushiroad/models/"):
+        for filename in filenames:
+            modelPaths.append(os.path.join(dirpath, filename))
+    hps = utils.get_hparams_from_file('Data/Bushiroad/configs/config.json')
+    net_g = get_net_g(
+        model_path=modelPaths[-1], version="2.1", device=device, hps=hps
+    )
+    speaker_ids = hps.data.spk2id
+    speakers = list(speaker_ids.keys())
+    app.run(host="0.0.0.0", port=5000)
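
Usage note: the sketch below is a minimal client for the new /tts route; it is not part of the commit. It assumes the server is reachable at http://localhost:5000 (the host and port hard-coded in app.run), and "MySpeaker" and out.wav are hypothetical placeholders — the speaker value must be a key in hps.data.spk2id.

import requests

params = {
    "speaker": "MySpeaker",  # hypothetical; must exist in hps.data.spk2id
    "text": "こんにちは",      # Japanese characters trigger is_japanese() -> 'JP'
    "sdp_ratio": 0.2,        # values mirror the route's request.args defaults
    "noise_scale": 0.6,
    "noise_scale_w": 0.8,
    "length_scale": 1.0,
}
resp = requests.get("http://localhost:5000/tts", params=params, timeout=120)
resp.raise_for_status()
# The route returns raw WAV bytes (written at 44100 Hz) and reports the
# infer() status string in a custom 'Text' response header.
with open("out.wav", "wb") as f:
    f.write(resp.content)
print(resp.headers.get("Text"))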