Spaces:
Sleeping
Sleeping
""" | |
Copyright 2022 Balacoon | |
TTS interactive demo | |
""" | |
import os | |
import glob | |
import logging | |
from typing import cast | |
from threading import Lock | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
import gradio as gr | |
from balacoon_tts import TTS | |
from huggingface_hub import hf_hub_download, list_repo_files | |
import torch | |
from conversation import get_default_conv_template | |
# lock that prevents more than one thread from using the global `tts` object at a time
locker = Lock()
# global tts module, (re)initialized from the currently selected model file
tts = None
# path to the model file currently loaded into `tts` (None until one is loaded)
cur_model_path = None
# cache of speakers: maps model file name to the speaker list reported by that model
model_to_speakers = dict()
# local directory the balacoon/tts model files are downloaded into
model_repo_dir = "/data"
# At import time, list the files of the balacoon/tts Hub repo and download each
# one that is not already present in `model_repo_dir` (queries the Hugging Face Hub).
for name in list_repo_files(repo_id="balacoon/tts"):
    if not os.path.isfile(os.path.join(model_repo_dir, name)):
        hf_hub_download(
            repo_id="balacoon/tts",
            filename=name,
            local_dir=model_repo_dir,
        )
# Speech-to-text pipeline (Whisper large-v3), constructed once at import time.
stt_pipe = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-large-v3",
)
# Chat models keyed by a short name. Each entry holds the tokenizer, the causal LM
# (device placement chosen by `device_map="auto"`), and a conversation template.
# NOTE(review): "conv" is a single module-level object shared by every request;
# code that appends messages to it directly mutates state seen by all users.
talkers = {
    "m3b": {
        "tokenizer": AutoTokenizer.from_pretrained("GeneZC/MiniChat-3B", use_fast=False),
        "model": AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", device_map="auto"),
        "conv": get_default_conv_template("minichat")
    }
}
def transcribe_stt(audio):
    """Transcribe an audio file to English text with the Whisper pipeline.

    Args:
        audio: path to the submitted audio file (the UI's audio components use
            ``type="filepath"``), or None when nothing was submitted.

    Returns:
        The ``"text"`` field of the ``stt_pipe`` result.

    Raises:
        gr.Error: when no audio was submitted, so Gradio shows the message in the UI.
    """
    if audio is not None:
        whisper_kwargs = {"language": "english", "task": "transcribe"}
        transcription = stt_pipe(audio, generate_kwargs=whisper_kwargs)
        return transcription["text"]
    raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
def m3b_talk(text):
    """Generate a single-turn MiniChat-3B reply to *text*.

    A private copy of the "minichat" conversation template gets the user
    message plus an empty assistant turn and is rendered into a prompt. The
    prompt is tokenized and passed to ``model.generate`` (sampling with
    temperature 0.2, at most 1024 new tokens). Only the tokens generated
    after the prompt are decoded.

    Args:
        text: the user's message.

    Returns:
        The decoded reply with special tokens removed and surrounding
        whitespace stripped.

    Raises:
        gr.Error: if *text* is empty or whitespace-only.
    """
    import copy  # only this function needs it

    if not text or not text.strip():
        raise gr.Error("No message submitted! Please type a message before sending.")

    talker = talkers["m3b"]
    tokenizer = talker["tokenizer"]
    model = talker["model"]

    # Work on a copy: the template in `talkers` is shared by all requests, and
    # appending to it would leak earlier messages (and dangling empty assistant
    # turns) into every later prompt.
    conv = copy.deepcopy(talker["conv"])
    conv.append_message(conv.roles[0], text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Tokenize the full templated prompt (not the raw text) so the slice below
    # removes exactly the prompt tokens from the generated sequence.
    input_ids = tokenizer([prompt]).input_ids
    output_ids = model.generate(
        torch.as_tensor(input_ids).to(model.device),
        do_sample=True,
        temperature=0.2,
        max_new_tokens=1024,
    )
    response_ids = output_ids[0][len(input_ids[0]):]
    return tokenizer.decode(response_ids, skip_special_tokens=True).strip()
def _load_tts(model_path: str) -> None:
    """Replace the global `tts` with a model loaded from *model_path*.

    Callers must hold `locker`. The reference to the previous model is dropped
    before the new one is constructed, as the original `del tts` did. Both
    globals are reset first, so a failing `TTS(...)` leaves `tts is None` and
    `cur_model_path is None` and the next request retries the load. The
    alternative is a deleted global `tts` while `cur_model_path` still names
    the old model.
    """
    global tts, cur_model_path
    tts = None
    cur_model_path = None
    tts = TTS(model_path)
    cur_model_path = model_path


def main():
    """Build the Gradio demo and launch it.

    The page has three sections:
      * text-to-speech with a balacoon model and speaker picked from dropdowns,
      * speech-to-text via `transcribe_stt`,
      * a chat box backed by `m3b_talk` (MiniChat-3B).
    Requests are queued with `concurrency_count=1`.
    """
    logging.basicConfig(level=logging.INFO)
    with gr.Blocks() as demo:
        # Blank lines matter: without one after the HTML heading, Markdown
        # treats the following list as part of the raw HTML block.
        gr.Markdown(
            '<h1 align="center">Balacoon🦝 Text-to-Speech</h1>\n'
            "\n"
            "1. Write an utterance to generate,\n"
            "2. Select the model to synthesize with\n"
            "3. Select speaker\n"
            '4. Hit "Generate" and listen to the result!\n'
            "\n"
            "You can learn more about models available\n"
            "[here](https://huggingface.co/balacoon/tts).\n"
            "\n"
            "Visit [Balacoon website](https://balacoon.com/) for more info.\n"
        )
        with gr.Row(variant="panel"):
            text = gr.Textbox(label="Text", placeholder="Type something here...")
        with gr.Row():
            with gr.Column(variant="panel"):
                # model choices: every CPU addon file present in the download dir
                repo_files = os.listdir(model_repo_dir)
                model_files = [x for x in repo_files if x.endswith("_cpu.addon")]
                model_name = gr.Dropdown(
                    label="Model",
                    choices=model_files,
                )
            with gr.Column(variant="panel"):
                speaker = gr.Dropdown(label="Speaker", choices=[])

        def set_model(model_name_str: str):
            """Refresh the speaker dropdown for the model chosen in `model_name`.

            Uses the cached speaker list when available. Otherwise it loads the
            model into the global `tts` (under `locker`) to query its speakers,
            then caches them. The last speaker is preselected; an empty list
            yields no selection instead of an IndexError.
            """
            if model_name_str in model_to_speakers:
                speakers = model_to_speakers[model_name_str]
            else:
                with locker:
                    # need to load this model to learn the list of speakers
                    _load_tts(os.path.join(model_repo_dir, model_name_str))
                    speakers = tts.get_speakers()
                model_to_speakers[model_name_str] = speakers
            value = speakers[-1] if speakers else None
            return gr.Dropdown.update(
                choices=speakers, value=value, visible=True
            )

        model_name.change(set_model, inputs=model_name, outputs=speaker)
        with gr.Row(variant="panel"):
            generate = gr.Button("Generate")
        with gr.Row(variant="panel"):
            audio = gr.Audio()
        with gr.Row(variant="panel"):
            gr.Markdown("## Transcribe\n\nTranscribe audio to text.")
        with gr.Row(variant="panel"):
            with gr.Column(variant="panel"):
                # NOTE(review): `stt_input_mic` is not wired to any event; the
                # Transcribe button below only reads `stt_input_file`, so
                # recordings are ignored. Wiring both would change the inputs of
                # the public "transcribe" API endpoint; confirm before changing.
                stt_input_mic = gr.Audio(source="microphone", type="filepath", label="Record")
                stt_input_file = gr.Audio(source="upload", type="filepath", label="Upload")
            with gr.Column(variant="panel"):
                stt_transcribe_output = gr.Textbox()
                stt_transcribe_btn = gr.Button("Transcribe")
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to MiniChat-3B\n\nTalk to MiniChat-3B.")
        with gr.Row(variant="panel"):
            with gr.Column(variant="panel"):
                m3b_talk_input = gr.Textbox(label="Message", placeholder="Type something here...")
            with gr.Column(variant="panel"):
                m3b_talk_output = gr.Textbox()
                m3b_talk_btn = gr.Button("Send")

        def synthesize_audio(text_str: str, model_name_str: str, speaker_str: str):
            """Synthesize `text_str` with the selected model and speaker.

            Returns None (and logs) when the text, model name, or speaker is
            missing. Otherwise it truncates the text to 1024 characters, makes
            sure the selected model is loaded, and synthesizes. Model loading,
            synthesis, and the sampling-rate read all happen under `locker`.
            The result is an update for the `audio` component.

            NOTE(review): an empty speaker short-circuits to None, so single-speaker
            models with an empty speaker name cannot be synthesized here; confirm
            whether `TTS.synthesize` accepts an empty speaker before relaxing this.
            """
            if not text_str or not model_name_str or not speaker_str:
                logging.info("text, model name or speaker are not provided")
                return None
            expected_model_path = os.path.join(model_repo_dir, model_name_str)
            if len(text_str) > 1024:
                # truncate the text
                text_str = text_str[:1024]
            with locker:
                if expected_model_path != cur_model_path:
                    _load_tts(expected_model_path)
                samples = tts.synthesize(text_str, speaker_str)
                sampling_rate = tts.get_sampling_rate()
            return gr.Audio.update(value=(sampling_rate, samples))

        generate.click(synthesize_audio, inputs=[text, model_name, speaker], outputs=audio, api_name="synthesize")
        stt_transcribe_btn.click(transcribe_stt, inputs=stt_input_file, outputs=stt_transcribe_output, api_name="transcribe")
        m3b_talk_btn.click(m3b_talk, inputs=m3b_talk_input, outputs=m3b_talk_output, api_name="talk_m3b")
    demo.queue(concurrency_count=1).launch()
# Build and launch the Gradio demo only when this file is executed as a script.
if __name__ == "__main__":
    main()