Samuel L Meyers committed on
Commit
da8a172
·
1 Parent(s): e71462a

Initial MiniChat test

Browse files
Files changed (1) hide show
  1. app.py +2 -137
app.py CHANGED
@@ -1,44 +1,11 @@
1
- """
2
- Copyright 2022 Balacoon
3
-
4
- TTS interactive demo
5
- """
6
-
7
- import os
8
- import glob
9
  import logging
10
  from typing import cast
11
  from threading import Lock
12
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
13
-
14
- import gradio as gr
15
- from balacoon_tts import TTS
16
- from huggingface_hub import hf_hub_download, list_repo_files
17
  import torch
18
 
19
  from conversation import get_default_conv_template
20
-
21
- # locker that disallows access to the tts object from more than one thread
22
- locker = Lock()
23
- # global tts module, initialized from a model selected
24
- tts = None
25
- # path to the model that is currently used in tts
26
- cur_model_path = None
27
- # cache of speakers, maps model name to speaker list
28
- model_to_speakers = dict()
29
- model_repo_dir = "/data"
30
- for name in list_repo_files(repo_id="balacoon/tts"):
31
- if not os.path.isfile(os.path.join(model_repo_dir, name)):
32
- hf_hub_download(
33
- repo_id="balacoon/tts",
34
- filename=name,
35
- local_dir=model_repo_dir,
36
- )
37
-
38
- stt_pipe = pipeline(
39
- task="automatic-speech-recognition",
40
- model="openai/whisper-large-v3",
41
- )
42
 
43
  talkers = {
44
  "m3b": {
@@ -48,12 +15,6 @@ talkers = {
48
  }
49
  }
50
 
51
- def transcribe_stt(audio):
52
- if audio is None:
53
- raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
54
- text = stt_pipe(audio, generate_kwargs={"language": "english", "task": "transcribe"})["text"]
55
- return text
56
-
57
  def m3b_talk(text):
58
  m3bconv = talkers["m3b"]["conv"]
59
  m3bconv.append_message(m3bconv.roles[0], text)
@@ -73,75 +34,6 @@ def main():
73
  logging.basicConfig(level=logging.INFO)
74
 
75
  with gr.Blocks() as demo:
76
- gr.Markdown(
77
- """
78
- <h1 align="center">Balacoon🦝 Text-to-Speech</h1>
79
-
80
- 1. Write an utterance to generate,
81
- 2. Select the model to synthesize with
82
- 3. Select speaker
83
- 4. Hit "Generate" and listen to the result!
84
-
85
- You can learn more about models available
86
- [here](https://huggingface.co/balacoon/tts).
87
- Visit [Balacoon website](https://balacoon.com/) for more info.
88
- """
89
- )
90
- with gr.Row(variant="panel"):
91
- text = gr.Textbox(label="Text", placeholder="Type something here...")
92
-
93
- with gr.Row():
94
- with gr.Column(variant="panel"):
95
- repo_files = os.listdir(model_repo_dir)
96
- model_files = [x for x in repo_files if x.endswith("_cpu.addon")]
97
- model_name = gr.Dropdown(
98
- label="Model",
99
- choices=model_files,
100
- )
101
- with gr.Column(variant="panel"):
102
- speaker = gr.Dropdown(label="Speaker", choices=[])
103
-
104
- def set_model(model_name_str: str):
105
- """
106
- gets value from `model_name`. either
107
- uses cached list of speakers for the given model name
108
- or loads the addon and checks what are the speakers.
109
- """
110
- global model_to_speakers
111
- if model_name_str in model_to_speakers:
112
- speakers = model_to_speakers[model_name_str]
113
- else:
114
- global tts, cur_model_path, locker
115
- with locker:
116
- # need to load this model to learn the list of speakers
117
- model_path = os.path.join(model_repo_dir, model_name_str)
118
- if tts is not None:
119
- del tts
120
- tts = TTS(model_path)
121
- cur_model_path = model_path
122
- speakers = tts.get_speakers()
123
- model_to_speakers[model_name_str] = speakers
124
-
125
- value = speakers[-1]
126
- return gr.Dropdown.update(
127
- choices=speakers, value=value, visible=True
128
- )
129
-
130
- model_name.change(set_model, inputs=model_name, outputs=speaker)
131
-
132
- with gr.Row(variant="panel"):
133
- generate = gr.Button("Generate")
134
- with gr.Row(variant="panel"):
135
- audio = gr.Audio()
136
- with gr.Row(variant="panel"):
137
- gr.Markdown("## Transcribe\n\nTranscribe audio to text.")
138
- with gr.Row(variant="panel"):
139
- with gr.Column(variant="panel"):
140
- stt_input_mic = gr.Audio(source="microphone", type="filepath", label="Record")
141
- stt_input_file = gr.Audio(source="upload", type="filepath", label="Upload")
142
- with gr.Column(variant="panel"):
143
- stt_transcribe_output = gr.Textbox()
144
- stt_transcribe_btn = gr.Button("Transcribe")
145
  with gr.Row(variant="panel"):
146
  gr.Markdown("## Talk to MiniChat-3B\n\nTalk to MiniChat-3B.")
147
  with gr.Row(variant="panel"):
@@ -151,33 +43,6 @@ def main():
151
  m3b_talk_output = gr.Textbox()
152
  m3b_talk_btn = gr.Button("Send")
153
 
154
- def synthesize_audio(text_str: str, model_name_str: str, speaker_str: str):
155
- """
156
- gets utterance to synthesize from `text` Textbox
157
- and speaker name from `speaker` dropdown list.
158
- speaker name might be empty for single-speaker models.
159
- Synthesizes the waveform and updates `audio` with it.
160
- """
161
- if not text_str or not model_name_str or not speaker_str:
162
- logging.info("text, model name or speaker are not provided")
163
- return None
164
- expected_model_path = os.path.join(model_repo_dir, model_name_str)
165
- global tts, cur_model_path, locker
166
- with locker:
167
- if expected_model_path != cur_model_path:
168
- # reload model
169
- if tts is not None:
170
- del tts
171
- tts = TTS(expected_model_path)
172
- cur_model_path = expected_model_path
173
- if len(text_str) > 1024:
174
- # truncate the text
175
- text_str = text_str[:1024]
176
- samples = tts.synthesize(text_str, speaker_str)
177
- return gr.Audio.update(value=(tts.get_sampling_rate(), samples))
178
-
179
- generate.click(synthesize_audio, inputs=[text, model_name, speaker], outputs=audio, api_name="synthesize")
180
- stt_transcribe_btn.click(transcribe_stt, inputs=stt_input_file, outputs=stt_transcribe_output, api_name="transcribe")
181
  m3b_talk_btn.click(m3b_talk, inputs=m3b_talk_input, outputs=m3b_talk_output, api_name="talk_m3b")
182
 
183
  demo.queue(concurrency_count=1).launch()
 
 
 
 
 
 
 
 
 
1
  import logging
2
  from typing import cast
3
  from threading import Lock
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
5
  import torch
6
 
7
  from conversation import get_default_conv_template
8
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  talkers = {
11
  "m3b": {
 
15
  }
16
  }
17
 
 
 
 
 
 
 
18
  def m3b_talk(text):
19
  m3bconv = talkers["m3b"]["conv"]
20
  m3bconv.append_message(m3bconv.roles[0], text)
 
34
  logging.basicConfig(level=logging.INFO)
35
 
36
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  with gr.Row(variant="panel"):
38
  gr.Markdown("## Talk to MiniChat-3B\n\nTalk to MiniChat-3B.")
39
  with gr.Row(variant="panel"):
 
43
  m3b_talk_output = gr.Textbox()
44
  m3b_talk_btn = gr.Button("Send")
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  m3b_talk_btn.click(m3b_talk, inputs=m3b_talk_input, outputs=m3b_talk_output, api_name="talk_m3b")
47
 
48
  demo.queue(concurrency_count=1).launch()