EPark25 committed on
Commit
1e937ce
·
1 Parent(s): d2e67c6

unreliably working

Browse files
Files changed (1) hide show
  1. app.py +20 -13
app.py CHANGED
@@ -1,10 +1,15 @@
1
  import gradio as gr
 
2
  from huggingface_hub import InferenceClient
3
- from transformers import pipeline
4
- from scipy.io.wavfile import write as write_wav
5
 
6
- AUDIO_FILE_PATH = "bark_generation.wav"
7
- synthesizer = pipeline("text-to-speech", "suno/bark-small")
 
 
 
 
8
 
9
  """
10
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
@@ -18,10 +23,6 @@ with gr.Blocks() as demo:
18
  msg = gr.Textbox(submit_btn=True)
19
  clear = gr.Button("Clear")
20
 
21
- def synthesize_audio(text):
22
- speech = synthesizer(text, forward_params={"do_sample": True})
23
- write_wav(AUDIO_FILE_PATH, rate=speech["sampling_rate"], data=speech["audio"])
24
-
25
  def user(user_message, history: list):
26
  return "", history + [{"role": "user", "content": user_message}]
27
 
@@ -33,14 +34,20 @@ with gr.Blocks() as demo:
33
  ):
34
  token = message.choices[0].delta.content
35
  history[-1]["content"] += token
36
- yield history, None
 
 
37
 
38
- synthesize_audio(history[-1]["content"])
39
- return history, AUDIO_FILE_PATH
 
 
 
 
40
 
41
  msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
42
- bot, chatbot, [chatbot, audio_box]
43
- )
44
  clear.click(lambda: None, None, chatbot, queue=False)
45
 
46
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import torch
3
  from huggingface_hub import InferenceClient
4
+ from transformers import BarkModel
5
+ from transformers import AutoProcessor
6
 
7
+
8
# Load the small Bark text-to-speech checkpoint and move it to GPU when available.
model = BarkModel.from_pretrained("suno/bark-small")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Processor must match the loaded checkpoint: the original pulled "suno/bark"
# while the model above is "suno/bark-small" — align them per the model card.
processor = AutoProcessor.from_pretrained("suno/bark-small")
13
 
14
  """
15
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 
23
  msg = gr.Textbox(submit_btn=True)
24
  clear = gr.Button("Clear")
25
 
 
 
 
 
26
def user(user_message, history: list):
    """Handle a textbox submit: push the user's message onto the chat history.

    Returns ("", new_history) so the textbox is cleared while the chatbot
    component receives the extended history. The input history is not mutated.
    """
    new_history = history + [{"role": "user", "content": user_message}]
    return "", new_history
 
 
34
  ):
35
  token = message.choices[0].delta.content
36
  history[-1]["content"] += token
37
+ yield history
38
+
39
+ return history
40
 
41
def read(history: list):
    """Synthesize speech for the latest assistant reply in the chat history.

    Args:
        history: chat-message list; the last entry's "content" is spoken.

    Returns:
        A (sampling_rate, waveform) tuple (int, 1-D numpy array) as expected
        by a gr.Audio output component.
    """
    text = history[-1]["content"]
    # Tokenize and move the tensors to the model's device once; the original
    # redundantly called .to(device) a second time inside generate().
    inputs = processor(text=text, return_tensors="pt").to(device)
    # NOTE(review): Bark generation is sampling-based and nondeterministic by
    # default — confirm that's acceptable for this app.
    speech = model.generate(**inputs)
    sampling_rate = model.generation_config.sample_rate
    # tuple((a, b)) was a no-op wrapper around a tuple literal; return directly.
    return sampling_rate, speech.cpu().numpy().squeeze()
47
 
48
  msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
49
+ bot, chatbot, chatbot
50
+ ).then(read, chatbot, audio_box)
51
  clear.click(lambda: None, None, chatbot, queue=False)
52
 
53
  if __name__ == "__main__":