# Source: audio-webui/webui/ui/tabs/text_to_speech.py
# Uploaded by mrtroydev using huggingface_hub (commit 3883c60, verified).
import gradio
from gradio.components import IOComponent
import webui.modules.models as mod
import webui.modules.implementations.ttsmodels as tts_models
from webui.modules import util
# Model category passed to mod.get_installed_models() when listing models.
mod_type = 'text-to-speech'

# Starts out bound to the TTSModelLoader class itself; load_model() inside
# text_to_speech() rebinds it to whatever loader.from_model(...) returns.
loader: mod.TTSModelLoader = mod.TTSModelLoader

# Placeholders for widgets that text_to_speech() creates and publishes here
# through a `global` statement.
to_rvc = None
audio_out = None
def get_models_installed():
    """Return the names of the text-to-speech models offered in the model dropdown.

    The result is the concatenation of two lists, in this order:

    1. Every name from ``mod.get_installed_models(mod_type)`` that also appears
       in ``mod.all_tts_models()`` once ``'/'`` has been replaced by ``'--'``
       in the latter.
    2. The ``model`` attribute of every entry of ``mod.all_tts()`` whose
       ``no_install`` attribute is truthy.

    Returns:
        list: Model names, installed ones first.
    """
    # Build the lookup once instead of re-querying and re-normalising the
    # full model list for every installed model.
    known_names = {tts.replace('/', '--') for tts in mod.all_tts_models()}

    installed = [model for model in mod.get_installed_models(mod_type) if model in known_names]
    always_available = [model.model for model in mod.all_tts() if model.no_install]
    return installed + always_available
def filter_components(components):
    """Return a new list holding only the ``IOComponent`` instances of ``components``.

    The relative order of the kept entries is unchanged.
    """
    return list(filter(lambda item: isinstance(item, IOComponent), components))
def text_to_speech():
    """Build the text-to-speech tab and wire up its event handlers.

    The tab is one row with two columns. The first column holds the
    model-specific input elements supplied by ``tts_models``, the model
    dropdown and the refresh/unload buttons. The second column holds the
    "Generate" and "Send to RVC" buttons and the audio, video and file outputs.

    Side effects:
        Rebinds the module-level ``to_rvc`` and ``audio_out`` names to the
        widgets created here. The nested ``load_model`` handler rebinds the
        module-level ``loader``.

    NOTE(review): presumably this must be called while a gradio Blocks
    context is active - confirm against the caller.
    """
    global to_rvc, audio_out

    with gradio.Row():
        with gradio.Column():
            all_components_dict = tts_models.all_elements_dict()
            all_components = tts_models.all_elements(all_components_dict)
            with gradio.Row():
                selected = gradio.Dropdown(get_models_installed(), label='Model')
                with gradio.Column(elem_classes='smallsplit'):
                    # Labels are written as escapes so a re-encoding of this
                    # file cannot corrupt them: U+1F503 is the clockwise-arrows
                    # (refresh) emoji, U+1F4A3 is the bomb (unload) emoji.
                    refresh = gradio.Button('\U0001F503', variant='tool secondary')
                    unload = gradio.Button('\U0001F4A3', variant='tool primary')

            def refresh_models():
                """Re-scan the installed models and refresh the dropdown's choices."""
                # A bare list returned to a Dropdown output is treated as its
                # new *value*; an update dict is needed to replace the choices.
                return gradio.update(choices=get_models_installed())

            refresh.click(fn=refresh_models, outputs=selected, show_progress=True)

            def unload_model():
                """Unload the current model, clear the dropdown and hide every input element."""
                # Until a model has been loaded, `loader` is still the class
                # itself, so there is nothing to unload.
                if isinstance(loader, mod.TTSModelLoader):
                    loader.unload_model()
                # NOTE(review): `loader` stays bound to the unloaded instance.
                # If that instance keeps its `model` attribute, re-selecting
                # the same model will skip loading in load_model() - confirm
                # this is intended.
                return [gradio.update(value='')] + [gradio.update(visible=False) for _ in all_components]

            unload.click(fn=unload_model, outputs=[selected] + all_components, show_progress=True)

            # The Progress default is how gradio injects its progress tracker;
            # it is intentional, not a mutable-default mistake.
            def load_model(model, progress=gradio.Progress()):
                """Load ``model`` unless it is already current, then show only its input elements."""
                global loader
                progress(0, desc='Loading model.')
                already_loaded = hasattr(loader, 'model') and model.lower().endswith(loader.model.lower())
                if not already_loaded:
                    unload_model()
                    loader = loader.from_model(model)
                    loader.load_model(progress)
                inputs = all_components_dict[loader.model]
                # First update leaves the dropdown untouched; the rest toggle
                # visibility of each element depending on the active model and
                # the element's optional `hide` flag.
                return_value = [gradio.update()] + [
                    gradio.update(visible=element in inputs and not getattr(element, 'hide', False))
                    for element in all_components
                ]
                progress(1, desc='Loaded model!')
                return return_value

            selected.select(fn=load_model, inputs=selected, outputs=[selected] + all_components, show_progress=True)

        with gradio.Column():
            with gradio.Row():
                generate = gradio.Button('Generate', variant='primary', elem_id='tts-generate')
                to_rvc = gradio.Button('Send to RVC', elem_id='tts-to-rvc')
            with gradio.Row():
                audio_out = gradio.Audio(interactive=False)
            with gradio.Row():
                video_out = gradio.Video()
            with gradio.Row():
                file_out = gradio.File()

            # Only these components are passed as event inputs below.
            filtered_components = filter_components(all_components)

            def _generate(progress=gradio.Progress(track_tqdm=True), *values):
                """Run the loaded model on the values of its own input elements.

                ``values`` arrive in the same order as ``filtered_components``;
                values of elements that do not belong to the active model are
                dropped before calling the loader.

                Returns:
                    tuple: ``(response, util.make_waveform(response), file)``,
                    matching the audio, video and file outputs.
                """
                progress(0, desc='Generating audio...')
                active_elements = all_components_dict[loader.model]
                model_inputs = [
                    value
                    for component, value in zip(filtered_components, values)
                    if component in active_elements
                ]
                response, file = loader.get_response(*model_inputs, progress=progress)
                return response, util.make_waveform(response), file

            generate.click(fn=_generate, inputs=filtered_components,
                           outputs=[audio_out, video_out, file_out], show_progress=True)
# def _generate(inputs, values, progress=gradio.Progress()):
# global loader
# inputs = [values[i] for i in range(len(inputs)) if
# inputs[i] in all_components_dict[loader.model]] # Filter and convert inputs
# response, file = loader.get_response(*inputs, progress=progress)
# return response, util.make_waveform(response), file
#
# filtered_components = filter_components(all_components)
# generate.click(fn=lambda *values: _generate(filtered_components, values), inputs=filtered_components,
# outputs=[audio_out, video_out, file_out], show_progress=True)