File size: 5,945 Bytes
c02bdcd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
import random
from typing import Optional
from time import sleep
import gradio as gr
from tools.audio import float_to_int16, has_ffmpeg_installed, load_audio
from tools.logger import get_logger
logger = get_logger(" WebUI ")
from tools.seeder import TorchSeedContext
from tools.normalizer import normalizer_en_nemo_text, normalizer_zh_tn
import ChatTTS
# Shared ChatTTS instance used by every callback in this module.
chat = ChatTTS.Chat(get_logger("ChatTTS"))

# Local model directory remembered by load_chat() when one is given, so that
# reload_chat() can load from the same place again.
custom_path: Optional[str] = None

# Flags shared between UI callbacks: set_buttons_before_generate() and
# interrupt_generate() write them; generate_audio() and reload_chat() read them.
has_interrupted = False
is_in_generate = False

# Inclusive bounds used by generate_seed() (random.randint is inclusive);
# the upper bound is the largest unsigned 32-bit value.
seed_min = 1
seed_max = 2**32 - 1

# Without ffmpeg, audio output falls back to WAV files.
use_mp3 = has_ffmpeg_installed()
if not use_mp3:
    logger.warning("no ffmpeg installed, use wav file output")
# Voice presets: each named timbre maps to a preset seed (read by on_voice_change).
# "Timbre1".."Timbre9" use the repeated-digit seeds 1111, 2222, ..., 9999.
voices = {
    "Default": {"seed": 2},
    **{f"Timbre{n}": {"seed": 1111 * n} for n in range(1, 10)},
}
def generate_seed():
    """Draw a random seed in ``[seed_min, seed_max]`` and wrap it as a Gradio update."""
    new_seed = random.randint(seed_min, seed_max)
    return gr.update(value=new_seed)
def on_voice_change(vocie_selection):
    """Return the preset seed for the selected voice.

    Args:
        vocie_selection: Key into ``voices``. The parameter name keeps its
            original spelling so existing keyword callers keep working.

    Returns:
        The ``"seed"`` of the selected preset. An unknown selection falls back
        to the ``"Default"`` preset instead of failing with
        ``TypeError: 'NoneType' object is not subscriptable``.
    """
    preset = voices.get(vocie_selection, voices["Default"])
    return preset["seed"]
def on_audio_seed_change(audio_seed_input):
    """Sample a speaker from ``chat`` reproducibly for ``audio_seed_input``.

    The sampling call runs inside ``TorchSeedContext(audio_seed_input)``.
    That context presumably seeds torch's RNG for the ``with`` body (TODO
    confirm in tools.seeder), so the same seed yields the same speaker.
    """
    with TorchSeedContext(audio_seed_input):
        speaker = chat.sample_random_speaker()
    return speaker
def _register_normalizer(lang: str, factory, package: str) -> None:
    """Best-effort registration of a text normalizer for ``lang`` on ``chat``.

    ``factory`` is called inside the ``try``. A missing optional package
    (``package``) is therefore logged with an install hint instead of
    aborting the model load.
    """
    try:
        chat.normalizer.register(lang, factory())
    except ValueError as e:
        # ValueError from the factory or from register() is logged as-is;
        # presumably register() raises it for an already-registered lang on
        # reload — TODO confirm against ChatTTS.
        logger.error(e)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # still propagate; any other failure is treated as a missing package.
        logger.warning("Package %s not found!", package)
        logger.warning(
            "Run: conda install -c conda-forge pynini=2.1.5 && pip install %s",
            package,
        )


def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
    """Load the ChatTTS model and register the optional text normalizers.

    Args:
        cust_path: Local model directory, or ``None`` to call ``chat.load``
            with its default source. A non-``None`` path is stored in the
            module-level ``custom_path`` so ``reload_chat`` can reuse it.
        coef: DVAE coefficient string forwarded to ``chat.load`` (``None``
            for the default).

    Returns:
        The result of ``chat.load``. The English (nemo_text_processing) and
        Chinese (WeTextProcessing) normalizers are registered only when it is
        truthy.
    """
    global custom_path
    if cust_path is None:
        ret = chat.load(coef=coef)
    else:
        logger.info("local model path: %s", cust_path)
        ret = chat.load("custom", custom_path=cust_path, coef=coef)
        custom_path = cust_path
    if ret:
        _register_normalizer("en", normalizer_en_nemo_text, "nemo_text_processing")
        _register_normalizer("zh", normalizer_zh_tn, "WeTextProcessing")
    return ret
# Required length of a DVAE coefficient string; other lengths are discarded.
_DVAE_COEF_LEN = 230


def reload_chat(coef: Optional[str]) -> str:
    """Unload and reload the model, optionally with a new DVAE coefficient.

    Args:
        coef: Coefficient string from the UI. A blank value or ``None`` means
            "use the default". A value of the wrong length is dropped with a
            warning.

    Returns:
        ``chat.coef`` after a successful reload. If a generation is in
        progress, the reload is refused and ``coef`` is returned unchanged.

    Raises:
        gr.Error: If ``load_chat`` raises or reports failure. The model stays
            unloaded in that case, because ``chat.unload()`` has already run.
    """
    if is_in_generate:
        gr.Warning("Cannot reload when generating!")
        return coef
    chat.unload()
    gr.Info("Model unloaded.")
    if not coef:
        # Blank field: use the default coefficient silently. Previously a
        # ``None`` value crashed here on ``len(None)``.
        coef = None
    elif len(coef) != _DVAE_COEF_LEN:
        gr.Warning("Ignore invalid DVAE coefficient.")
        coef = None
    try:
        ret = load_chat(custom_path, coef)
    except Exception as e:
        raise gr.Error(str(e)) from e
    if not ret:
        raise gr.Error("Unable to load model.")
    gr.Info("Reload success.")
    return chat.coef
def on_upload_sample_audio(sample_audio_input: Optional[str]) -> str:
    """Encode an uploaded reference clip with ``chat.sample_audio_speaker``.

    The clip is loaded via ``load_audio(..., 24000)`` (24000 is presumably
    the target sample rate in Hz). Returns an empty string when nothing was
    uploaded.
    """
    if sample_audio_input is not None:
        waveform = load_audio(sample_audio_input, 24000)
        speaker_code = chat.sample_audio_speaker(waveform)
        # Drop the decoded waveform eagerly; only the encoded result is returned.
        del waveform
        return speaker_code
    return ""
def _set_generate_buttons(generate_button, interrupt_button, is_reset=False):
    """Build the pair of Gradio updates that toggles generate vs. interrupt.

    With ``is_reset`` true, the generate button is shown and enabled and the
    interrupt button is hidden and disabled. With ``is_reset`` false (the
    default), it is the other way round.
    """
    show_interrupt = not is_reset
    generate_update = gr.update(
        value=generate_button, visible=is_reset, interactive=is_reset
    )
    interrupt_update = gr.update(
        value=interrupt_button, visible=show_interrupt, interactive=show_interrupt
    )
    return generate_update, interrupt_update
def refine_text(
    text,
    text_seed_input,
    refine_text_flag,
    temperature,
    top_P,
    top_K,
):
    """Optionally run ChatTTS text refinement on ``text``.

    When ``refine_text_flag`` is false, the input is returned unchanged after
    a one-second pause. The pause keeps the UI's loading mark from being
    skipped by a too-fast answer. Otherwise ``chat.infer`` is called in
    refine-only mode with the given sampling parameters and seed.
    """
    if refine_text_flag:
        refine_params = ChatTTS.Chat.RefineTextParams(
            temperature=temperature,
            top_P=top_P,
            top_K=top_K,
            manual_seed=text_seed_input,
        )
        refined = chat.infer(
            text,
            skip_refine_text=False,
            refine_text_only=True,
            params_refine_text=refine_params,
        )
        # infer() may return a list (presumably one entry per input); unwrap it.
        if isinstance(refined, list):
            return refined[0]
        return refined
    sleep(1)  # to skip fast answer of loading mark
    return text
def generate_audio(
    text,
    temperature,
    top_P,
    top_K,
    spk_emb_text: str,
    stream,
    audio_seed_input,
    sample_text_input,
    sample_audio_code_input,
):
    """Synthesize speech for ``text``, yielding ``(24000, int16 audio)`` tuples.

    This is a generator. Nothing is yielded when any of these holds:
      - ``text`` is empty;
      - an interrupt was requested;
      - ``spk_emb_text`` does not start with the expected "蘁淰" prefix.

    When both ``sample_text_input`` and ``sample_audio_code_input`` are given,
    that sample pair is used and the speaker embedding is cleared.

    In streaming mode, each non-empty chunk is yielded as it arrives.
    Otherwise a single full clip is yielded.
    """
    should_skip = (
        not text or has_interrupted or not spk_emb_text.startswith("蘁淰")
    )
    if should_skip:
        return
    params_infer_code = ChatTTS.Chat.InferCodeParams(
        spk_emb=spk_emb_text,
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        manual_seed=audio_seed_input,
    )
    if sample_text_input and sample_audio_code_input:
        # The reference sample takes precedence over the speaker embedding.
        params_infer_code.txt_smp = sample_text_input
        params_infer_code.spk_smp = sample_audio_code_input
        params_infer_code.spk_emb = None
    result = chat.infer(
        text,
        skip_refine_text=True,
        params_infer_code=params_infer_code,
        stream=stream,
    )
    if not stream:
        yield 24000, float_to_int16(result[0]).T
        return
    for chunk in result:
        audio = chunk[0]
        if audio is not None and len(audio) > 0:
            yield 24000, float_to_int16(audio).T
            del audio
def interrupt_generate():
    """Request that the running generation stop.

    Sets the module-level ``has_interrupted`` flag, which ``generate_audio``
    checks before starting. Then forwards the request via ``chat.interrupt()``.
    """
    global has_interrupted
    has_interrupted = True
    chat.interrupt()
def set_buttons_before_generate(generate_button, interrupt_button):
    """Mark a generation as started and swap in the interrupt button.

    Clears any earlier interrupt request. Also sets ``is_in_generate``, so
    that ``reload_chat`` refuses to run meanwhile.
    """
    global has_interrupted, is_in_generate
    has_interrupted = False
    is_in_generate = True
    return _set_generate_buttons(generate_button, interrupt_button, is_reset=False)
def set_buttons_after_generate(generate_button, interrupt_button, audio_output):
    """Mark the generation as finished and decide whether to restore the buttons.

    The generate button comes back only when ``audio_output`` is not ``None``
    or the run was interrupted. Otherwise the interrupt button stays visible.
    NOTE(review): whether ``audio_output`` is populated at this point in
    streaming mode depends on the Gradio event wiring — confirm.
    """
    global is_in_generate
    is_in_generate = False
    restore = audio_output is not None or has_interrupted
    return _set_generate_buttons(generate_button, interrupt_button, restore)
|