Create frontend.py
frontend.py +215 -0
frontend.py
ADDED
@@ -0,0 +1,215 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Generator
import json
import onnxruntime
import torch
import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import re
import inflect
try:
    import ttsfrd
    use_ttsfrd = True
except ImportError:
    print("failed to import ttsfrd, use WeTextProcessing instead")
    from tn.chinese.normalizer import Normalizer as ZhNormalizer
    from tn.english.normalizer import Normalizer as EnNormalizer
    use_ttsfrd = False
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


class CosyVoiceFrontEnd:

    def __init__(self,
                 get_tokenizer: Callable,
                 feat_extractor: Callable,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 spk2info: str = '',
                 allowed_special: str = 'all'):
        self.tokenizer = get_tokenizer()
        self.feat_extractor = feat_extractor
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                                "CPUExecutionProvider"])
        if os.path.exists(spk2info):
            self.spk2info = torch.load(spk2info, map_location=self.device)
        else:
            self.spk2info = {}
        self.allowed_special = allowed_special
        self.use_ttsfrd = use_ttsfrd
        if self.use_ttsfrd:
            self.frd = ttsfrd.TtsFrontendEngine()
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
                'failed to initialize ttsfrd resource'
            self.frd.set_lang_type('pinyinvg')
        else:
            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
            self.en_tn_model = EnNormalizer()
            self.inflect_parser = inflect.engine()

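    # Text tokenization: returns (token_ids, token_len) tensors on self.device;
    # a generator input is passed through token-by-token for streaming use.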
    def _extract_text_token(self, text):
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will return _extract_text_token_generator!')
            # NOTE add a dummy text_token_len for compatibility
            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
        else:
            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
            return text_token, text_token_len

    def _extract_text_token_generator(self, text_generator):
        for text in text_generator:
            text_token, _ = self._extract_text_token(text)
            for i in range(text_token.shape[1]):
                yield text_token[:, i: i + 1]

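    # Speech tokenization: 128-bin whisper log-mel features are run through the
    # ONNX speech tokenizer to produce discrete speech tokens.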
    def _extract_speech_token(self, speech):
        assert speech.shape[1] / 16000 <= 30, 'do not support extracting speech tokens for audio longer than 30s'
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(None,
                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
                                                          feat.detach().cpu().numpy(),
                                                          self.speech_tokenizer_session.get_inputs()[1].name:
                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

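    # Speaker embedding: mean-normalized 80-dim kaldi fbank features are passed
    # through the campplus ONNX model.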
    def _extract_spk_embedding(self, speech):
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.campplus_session.run(None,
                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

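    # Acoustic features: delegate to the injected feat_extractor callable and
    # return a (batch, frames, bins) tensor plus its length.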
    def _extract_speech_feat(self, speech):
        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

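    # Text normalization: ttsfrd when available, otherwise WeTextProcessing
    # (zh/en) plus inflect; optionally splits the result into sentence chunks
    # sized for the tokenizer.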
    def text_normalize(self, text, split=True, text_frontend=True):
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will skip text_normalize!')
            return [text]
        if text_frontend is False or text == '':
            return [text] if split is True else text
        text = text.strip()
        if self.use_ttsfrd:
            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
            text = ''.join(texts)
        else:
            if contains_chinese(text):
                text = self.zh_tn_model.normalize(text)
                text = text.replace("\n", "")
                text = replace_blank(text)
                text = replace_corner_mark(text)
                text = text.replace(".", "。")
                text = text.replace(" - ", ",")
                text = remove_bracket(text)
                text = re.sub(r'[,,、]+$', '。', text)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
            else:
                text = self.en_tn_model.normalize(text)
                text = spell_out_number(text, self.inflect_parser)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
        texts = [i for i in texts if not is_only_punctuation(i)]
        return texts if split is True else text

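    # SFT mode: synthesize with a pretrained speaker selected by spk_id.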
    def frontend_sft(self, tts_text, spk_id):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        embedding = self.spk2info[spk_id]['embedding']
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

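    # Zero-shot mode: clone a voice from a short prompt recording and its
    # transcript; a cached zero_shot_spk_id can skip the prompt processing.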
    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        if zero_shot_spk_id == '':
            prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
            speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
            if resample_rate == 24000:
                # cosyvoice2: force the feature length to be exactly twice the token length
                token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
                speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
                speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
            embedding = self._extract_spk_embedding(prompt_speech_16k)
            model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                           'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                           'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                           'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                           'llm_embedding': embedding, 'flow_embedding': embedding}
        else:
            model_input = self.spk2info[zero_shot_spk_id]
        model_input['text'] = tts_text_token
        model_input['text_len'] = tts_text_token_len
        return model_input

    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
        # in cross-lingual mode, remove the text prompt from the llm input
        del model_input['prompt_text']
        del model_input['prompt_text_len']
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_instruct(self, tts_text, spk_id, instruct_text):
        model_input = self.frontend_sft(tts_text, spk_id)
        # in instruct mode, remove the speaker embedding from the llm input to avoid information leakage
        del model_input['llm_embedding']
        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
        model_input['prompt_text'] = instruct_text_token
        model_input['prompt_text_len'] = instruct_text_token_len
        return model_input

    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
        # keep the speech prompt for the flow model only
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
                       'flow_embedding': embedding}
        return model_input
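
Usage (a minimal sketch, not part of the diff above): in the CosyVoice repo the tokenizer and feature extractor are built from the model's hyperpyyaml config, so the config keys, file names, and model directory below follow the layout of a downloaded CosyVoice-300M model and should be treated as assumptions, not as part of this file.

    import torchaudio
    from hyperpyyaml import load_hyperpyyaml

    model_dir = 'pretrained_models/CosyVoice-300M'  # hypothetical local path
    with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
        configs = load_hyperpyyaml(f)

    frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                 configs['feat_extractor'],
                                 '{}/campplus.onnx'.format(model_dir),
                                 '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                 '{}/spk2info.pt'.format(model_dir),
                                 configs['allowed_special'])

    # a short mono 16 kHz recording of the voice to clone
    prompt_speech_16k, _ = torchaudio.load('prompt_16k.wav')

    # normalize the text, then build model inputs for zero-shot synthesis
    for text in frontend.text_normalize('Hello, this is a test.', split=True):
        model_input = frontend.frontend_zero_shot(text, 'transcript of the prompt audio',
                                                  prompt_speech_16k, 22050, '')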