Samuel L Meyers committed
Commit d487976 · 1 Parent(s): 06ae9a8
Files changed (1): app.py +36 -1
app.py CHANGED
@@ -9,11 +9,14 @@ import glob
 import logging
 from typing import cast
 from threading import Lock
-from transformers import pipeline
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 
 import gradio as gr
 from balacoon_tts import TTS
 from huggingface_hub import hf_hub_download, list_repo_files
+import torch
+
+from conversation import get_default_conv_template
 
 # locker that disallows access to the tts object from more than one thread
 locker = Lock()
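Note: conversation is not a transformers module. get_default_conv_template("minichat") matches the helper shipped in the GeneZC/MiniChat repository's conversation.py, so that file presumably has to sit next to app.py in the Space for this import to resolve.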
@@ -37,12 +40,35 @@ stt_pipe = pipeline(
     model="openai/whisper-large-v3",
 )
 
+talkers = {
+    "m3b": {
+        "tokenizer": AutoTokenizer.from_pretrained("GeneZC/MiniChat-3B", use_fast=False),
+        "model": AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", device_map="auto"),
+        "conv": get_default_conv_template("minichat")
+    }
+}
+
 def transcribe_stt(audio):
     if audio is None:
         raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
     text = stt_pipe(audio, generate_kwargs={"language": "english", "task": "transcribe"})["text"]
     return text
 
+def m3b_talk(text):
+    m3bconv = talkers["m3b"]["conv"]
+    m3bconv.append_message(m3bconv.roles[0], text)
+    m3bconv.append_message(m3bconv.roles[1], None)
+    input_ids = talkers["m3b"]["tokenizer"]([text]).input_ids
+    response_tokens = talkers["m3b"]["model"](
+        torch.as_tensor(m3bconv.get_prompt()),
+        do_sample=True,
+        temperature=0.2,
+        max_new_tokens=1024,
+    )
+    response_tokens = response_tokens[0][len(input_ids[0]):]
+    response = talkers["m3b"]["tokenizer"].decode(response_tokens, skip_special_tokens=True).strip()
+    return response
+
 def main():
     logging.basicConfig(level=logging.INFO)
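As committed, m3b_talk tokenizes only the raw user text, then passes the templated prompt string itself through torch.as_tensor into the model's forward call, which would raise at runtime (forward takes tensors and returns logits, not sampled tokens). A minimal sketch of a version that would run, keeping the commit's sampling settings; the name m3b_talk_fixed and the device handling are editor assumptions, not part of the commit:

def m3b_talk_fixed(text):
    # Build the MiniChat conversation prompt, as in the commit.
    m3bconv = talkers["m3b"]["conv"]
    m3bconv.append_message(m3bconv.roles[0], text)
    m3bconv.append_message(m3bconv.roles[1], None)
    tokenizer = talkers["m3b"]["tokenizer"]
    model = talkers["m3b"]["model"]
    # Encode the full templated prompt (not the bare user text) as tensors.
    input_ids = tokenizer([m3bconv.get_prompt()], return_tensors="pt").input_ids.to(model.device)
    # Sample a completion with generate() rather than calling the forward pass.
    output_ids = model.generate(
        input_ids,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=1024,
    )
    # Keep only the newly generated tokens and decode them.
    response_tokens = output_ids[0][input_ids.shape[1]:]
    return tokenizer.decode(response_tokens, skip_special_tokens=True).strip()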
 
@@ -116,6 +142,14 @@ def main():
             with gr.Column(variant="panel"):
                 stt_transcribe_output = gr.Textbox()
                 stt_transcribe_btn = gr.Button("Transcribe")
+        with gr.Row(variant="panel"):
+            gr.Markdown("## Talk to MiniChat-3B\n\nTalk to MiniChat-3B.")
+        with gr.Row(variant="panel"):
+            with gr.Column(variant="panel"):
+                m3b_talk_input = gr.Textbox(label="Message", placeholder="Type something here...")
+            with gr.Column(variant="panel"):
+                m3b_talk_output = gr.Textbox()
+                m3b_talk_btn = gr.Button("Send")
 
     def synthesize_audio(text_str: str, model_name_str: str, speaker_str: str):
         """
@@ -144,6 +178,7 @@ def main():
 
     generate.click(synthesize_audio, inputs=[text, model_name, speaker], outputs=audio, api_name="synthesize")
     stt_transcribe_btn.click(transcribe_stt, inputs=stt_input_file, outputs=stt_transcribe_output, api_name="transcribe")
+    m3b_talk_btn.click(m3b_talk, inputs=m3b_talk_input, outputs=m3b_talk_output, api_name="talk_m3b")
 
     demo.queue(concurrency_count=1).launch()
 
 
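Since each click handler is registered with an api_name, the Space exposes named endpoints that can be called programmatically. A quick sketch using gradio_client; the Space identifier is a placeholder, not taken from the commit:

from gradio_client import Client

# Placeholder Space id: substitute the actual owner/name of this Space.
client = Client("<user>/<space>")
reply = client.predict("Hello there!", api_name="/talk_m3b")
print(reply)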