freddyaboulton HF Staff committed on
Commit
7ee8f06
·
1 Parent(s): 9666de8
Files changed (1) hide show
  1. app.py +17 -11
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import io
2
- import math
3
  from threading import Thread
4
  import random
 
5
 
6
  import numpy as np
7
  import spaces
@@ -26,7 +26,7 @@ model = ParlerTTSForConditionalGeneration.from_pretrained(
26
  jenny_repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
27
  ).to(device)
28
 
29
- client = InferenceClient()
30
 
31
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
32
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
@@ -63,10 +63,9 @@ def numpy_to_mp3(audio_array, sampling_rate):
63
  sampling_rate = model.audio_encoder.config.sampling_rate
64
  frame_rate = model.audio_encoder.config.frame_rate
65
 
66
- @spaces.GPU
67
- def generate_base(audio):
68
 
69
- question = client.audtomatic_speech_recognition(audio)
 
70
 
71
  messages = [{"role": "sytem", "content": ("You are a magic 8 ball."
72
  "Someone will present to you a situation or question and your job "
@@ -74,9 +73,13 @@ def generate_base(audio):
74
  "'curiosity killed the cat' or 'The early bird gets the worm'.")},
75
  {"role": "user", "content": f"Please tell me what to do about {question}"}]
76
 
77
- response = client.chat_completion(messages, max_tokens=1024, seed=random.randint(1, 5000))
78
  response = response.choices[0].message.content
 
 
79
 
 
 
80
 
81
  play_steps_in_s = 1.0
82
  play_steps = int(frame_rate * play_steps_in_s)
@@ -85,7 +88,7 @@ def generate_base(audio):
85
  description_tokens = tokenizer(description, return_tensors="pt").to(device)
86
 
87
  streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
88
- prompt = tokenizer(sentence, return_tensors="pt").to(device)
89
 
90
  generation_kwargs = dict(
91
  input_ids=description_tokens.input_ids,
@@ -102,11 +105,12 @@ def generate_base(audio):
102
 
103
  for new_audio in streamer:
104
  print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
105
- yield story, numpy_to_mp3(new_audio, sampling_rate=sampling_rate)
106
 
107
  css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
108
  .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
109
 
 
110
  with gr.Blocks() as block:
111
  gr.HTML(
112
  f"""
@@ -117,10 +121,12 @@ with gr.Blocks() as block:
117
  with gr.Group():
118
  with gr.Row():
119
  audio_out = gr.Audio(visible=False, streaming=True)
120
- answer = gr.Textbox(label="Answer")
 
 
121
  with gr.Row():
122
- audio_in = gr.Audio(label="Speak you question", sources="microphone", format="filepath")
123
 
124
- audio_in.stop_recording(fn=generate_base, inputs=audio_in, outputs=[answer, audio_out])
125
 
126
  block.launch()
 
1
  import io
 
2
  from threading import Thread
3
  import random
4
+ import os
5
 
6
  import numpy as np
7
  import spaces
 
26
  jenny_repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
27
  ).to(device)
28
 
29
+ client = InferenceClient(token=os.getenv("HF_TOKEN"))
30
 
31
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
32
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
 
63
  sampling_rate = model.audio_encoder.config.sampling_rate
64
  frame_rate = model.audio_encoder.config.frame_rate
65
 
 
 
66
 
67
+ def generate_response(audio):
68
+ question = client.automatic_speech_recognition(audio)
69
 
70
  messages = [{"role": "sytem", "content": ("You are a magic 8 ball."
71
  "Someone will present to you a situation or question and your job "
 
73
  "'curiosity killed the cat' or 'The early bird gets the worm'.")},
74
  {"role": "user", "content": f"Please tell me what to do about {question}"}]
75
 
76
+ response = client.chat_completion(messages, max_tokens=128, seed=random.randint(1, 5000))
77
  response = response.choices[0].message.content
78
+ yield response, gr.Textbox(visible=True)
79
+
80
 
81
+ @spaces.GPU
82
+ def generate_base(answer):
83
 
84
  play_steps_in_s = 1.0
85
  play_steps = int(frame_rate * play_steps_in_s)
 
88
  description_tokens = tokenizer(description, return_tensors="pt").to(device)
89
 
90
  streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
91
+ prompt = tokenizer(answer, return_tensors="pt").to(device)
92
 
93
  generation_kwargs = dict(
94
  input_ids=description_tokens.input_ids,
 
105
 
106
  for new_audio in streamer:
107
  print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
108
+ yield gr.Textbox(value=answer, visible=True), numpy_to_mp3(new_audio, sampling_rate=sampling_rate)
109
 
110
  css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
111
  .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
112
 
113
+
114
  with gr.Blocks() as block:
115
  gr.HTML(
116
  f"""
 
121
  with gr.Group():
122
  with gr.Row():
123
  audio_out = gr.Audio(visible=False, streaming=True)
124
+ answer = gr.Textbox(visible=False, label="Answer")
125
+ state = gr.State()
126
+
127
  with gr.Row():
128
+ audio_in = gr.Audio(label="Speak you question", sources="microphone", type="filepath")
129
 
130
+ audio_in.stop_recording(generate_response, audio_in, [state, answer]).then(fn=generate_base, inputs=state, outputs=[answer, audio_out])
131
 
132
  block.launch()