# ChatTTS2/examples/web/funcs.py
import random
from time import sleep
from typing import Optional

import gradio as gr

import ChatTTS
from tools.audio import float_to_int16, has_ffmpeg_installed, load_audio
from tools.logger import get_logger
from tools.normalizer import normalizer_en_nemo_text, normalizer_zh_tn
from tools.seeder import TorchSeedContext

logger = get_logger("WebUI")

chat = ChatTTS.Chat(get_logger("ChatTTS"))

custom_path: Optional[str] = None
has_interrupted = False
is_in_generate = False
seed_min = 1
seed_max = 4294967295
use_mp3 = has_ffmpeg_installed()
if not use_mp3:
    logger.warning("ffmpeg is not installed, falling back to WAV output")

# Voice presets: map a display name to a fixed speaker seed.
voices = {
    "Default": {"seed": 2},
    "Timbre1": {"seed": 1111},
    "Timbre2": {"seed": 2222},
    "Timbre3": {"seed": 3333},
    "Timbre4": {"seed": 4444},
    "Timbre5": {"seed": 5555},
    "Timbre6": {"seed": 6666},
    "Timbre7": {"seed": 7777},
    "Timbre8": {"seed": 8888},
    "Timbre9": {"seed": 9999},
}

def generate_seed():
    return gr.update(value=random.randint(seed_min, seed_max))

# Return the seed associated with the selected voice preset.
def on_voice_change(voice_selection):
    # Fall back to the default preset for unknown names instead of crashing.
    return voices.get(voice_selection, voices["Default"])["seed"]

def on_audio_seed_change(audio_seed_input):
    # Sample a speaker embedding deterministically from the given seed.
    with TorchSeedContext(audio_seed_input):
        rand_spk = chat.sample_random_speaker()
    return rand_spk

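# Illustrative sketch, not wired into the UI: under TorchSeedContext the
# speaker sampling above is deterministic, so equal seeds yield equal
# embedding strings. The helper name is ours and assumes the model is loaded.
def _demo_reproducible_speaker(seed: int = 2) -> bool:
    """Return True when two samples under the same seed match (expected)."""
    return on_audio_seed_change(seed) == on_audio_seed_change(seed)
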
def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
    if cust_path is None:
        ret = chat.load(coef=coef)
    else:
        logger.info("local model path: %s", cust_path)
        ret = chat.load("custom", custom_path=cust_path, coef=coef)
        global custom_path
        custom_path = cust_path
    if ret:
        try:
            chat.normalizer.register("en", normalizer_en_nemo_text())
        except ValueError as e:
            logger.error(e)
        except Exception:
            logger.warning("Package nemo_text_processing not found!")
            logger.warning(
                "Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing",
            )
        try:
            chat.normalizer.register("zh", normalizer_zh_tn())
        except ValueError as e:
            logger.error(e)
        except Exception:
            logger.warning("Package WeTextProcessing not found!")
            logger.warning(
                "Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
            )
    return ret

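# Illustrative sketch of the startup flow, mirroring how the WebUI is assumed
# to call load_chat; the helper name and argument values are placeholders.
def _demo_startup(model_dir: Optional[str] = None) -> None:
    """Load ChatTTS (downloading defaults when model_dir is None) and log it."""
    if load_chat(model_dir, None):
        logger.info("ChatTTS loaded, DVAE coef: %s", chat.coef)
    else:
        logger.error("ChatTTS failed to load")
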
def reload_chat(coef: Optional[str]) -> str:
    global is_in_generate

    if is_in_generate:
        gr.Warning("Cannot reload when generating!")
        return coef

    chat.unload()
    gr.Info("Model unloaded.")
    # A valid DVAE coefficient string is exactly 230 characters long.
    if coef is not None and len(coef) != 230:
        gr.Warning("Ignoring invalid DVAE coefficient.")
        coef = None
    try:
        global custom_path
        ret = load_chat(custom_path, coef)
    except Exception as e:
        raise gr.Error(str(e))
    if not ret:
        raise gr.Error("Unable to load model.")
    gr.Info("Reload succeeded.")
    return chat.coef

def on_upload_sample_audio(sample_audio_input: Optional[str]) -> str:
    if sample_audio_input is None:
        return ""
    # Resample the uploaded audio to 24 kHz and encode it into speaker tokens.
    sample_audio = load_audio(sample_audio_input, 24000)
    spk_smp = chat.sample_audio_speaker(sample_audio)
    del sample_audio
    return spk_smp

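# Illustrative sketch of the zero-shot cloning path (the file path is a
# placeholder): the token string returned here is what generate_audio expects
# as sample_audio_code_input, paired with the recording's transcript as
# sample_text_input.
def _demo_encode_reference(path: str = "reference.wav") -> str:
    """Encode a reference recording into the speaker-sample token string."""
    return on_upload_sample_audio(path)
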
def _set_generate_buttons(generate_button, interrupt_button, is_reset=False):
    # Toggle visibility/interactivity between Generate and Interrupt buttons.
    return gr.update(
        value=generate_button, visible=is_reset, interactive=is_reset
    ), gr.update(value=interrupt_button, visible=not is_reset, interactive=not is_reset)

def refine_text(
    text,
    text_seed_input,
    refine_text_flag,
    temperature,
    top_P,
    top_K,
):
    global chat

    if not refine_text_flag:
        # Brief pause so the UI's loading state is visible before returning.
        sleep(1)
        return text

    text = chat.infer(
        text,
        skip_refine_text=False,
        refine_text_only=True,
        params_refine_text=ChatTTS.Chat.RefineTextParams(
            temperature=temperature,
            top_P=top_P,
            top_K=top_K,
            manual_seed=text_seed_input,
        ),
    )
    return text[0] if isinstance(text, list) else text

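# Illustrative sketch (sampling values are placeholders): refining with a
# fixed manual seed makes the inserted prosody tokens such as [uv_break]
# repeatable across runs.
def _demo_refine(text: str = "Hello, ChatTTS.") -> str:
    """Refine text deterministically with seed 42."""
    return refine_text(text, 42, True, 0.7, 0.7, 20)
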
def generate_audio(
    text,
    temperature,
    top_P,
    top_K,
    spk_emb_text: str,
    stream,
    audio_seed_input,
    sample_text_input,
    sample_audio_code_input,
):
    global chat, has_interrupted

    # Bail out on empty input, a pending interrupt, or a malformed speaker
    # embedding (valid embeddings are strings prefixed with "蘁淰").
    if not text or has_interrupted or not spk_emb_text.startswith("蘁淰"):
        return

    params_infer_code = ChatTTS.Chat.InferCodeParams(
        spk_emb=spk_emb_text,
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        manual_seed=audio_seed_input,
    )
    # If a reference sample was uploaded, prefer zero-shot cloning over the
    # seeded speaker embedding.
    if sample_text_input and sample_audio_code_input:
        params_infer_code.txt_smp = sample_text_input
        params_infer_code.spk_smp = sample_audio_code_input
        params_infer_code.spk_emb = None

    wav = chat.infer(
        text,
        skip_refine_text=True,
        params_infer_code=params_infer_code,
        stream=stream,
    )
    if stream:
        for gen in wav:
            audio = gen[0]
            if audio is not None and len(audio) > 0:
                yield 24000, float_to_int16(audio).T
            del audio
    else:
        yield 24000, float_to_int16(wav[0]).T

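# Illustrative sketch of consuming generate_audio (all sampling values are
# placeholders): it is a generator either way, yielding one
# (sample_rate, int16 PCM array) tuple when stream=False and many when
# stream=True.
def _demo_generate(text: str = "Hello.", stream: bool = False) -> None:
    """Synthesize text with a seeded speaker and log each yielded chunk."""
    spk = on_audio_seed_change(2)
    for sr, pcm in generate_audio(text, 0.3, 0.7, 20, spk, stream, 2, "", ""):
        logger.info("chunk: %d samples at %d Hz", pcm.shape[-1], sr)
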
def interrupt_generate():
    global chat, has_interrupted

    has_interrupted = True
    chat.interrupt()

def set_buttons_before_generate(generate_button, interrupt_button):
    global has_interrupted, is_in_generate

    has_interrupted = False
    is_in_generate = True

    return _set_generate_buttons(
        generate_button,
        interrupt_button,
    )

def set_buttons_after_generate(generate_button, interrupt_button, audio_output):
    global has_interrupted, is_in_generate

    is_in_generate = False

    return _set_generate_buttons(
        generate_button,
        interrupt_button,
        audio_output is not None or has_interrupted,
    )