adeel707 committed
Commit 9d1f362 · verified · parent: e6614a9

Update app.py

Files changed (1)
app.py +234 -109
app.py CHANGED
@@ -1,28 +1,64 @@
  import os
  import torch
  import whisper
  import streamlit as st
  from groq import Groq
- from TTS.api import TTS
  from dotenv import load_dotenv
  from tempfile import NamedTemporaryFile
  from streamlit_webrtc import webrtc_streamer, WebRtcMode, ClientSettings
  import av
  import numpy as np
- import scipy.io.wavfile
- import scipy.sparse
-
- from huggingface_hub import HfApi
-
- # will use api to restart space on a unrecoverable error
- api = HfApi(token=HF_TOKEN)

- # Load API key from Hugging Face
  load_dotenv()
  API_KEY = os.getenv("GROQ_API_KEY")

  # LLM Response Function
  def get_llm_response(api_key, user_input):
      client = Groq(api_key=api_key)
      prompt = (
          "IMPORTANT: You are an AI assistant that MUST provide responses in 25 words or less.\n"
@@ -37,141 +73,230 @@ def get_llm_response(api_key, user_input):
          "Your response will be converted to speech. Maximum 25 words."
      )

-     chat_completion = client.chat.completions.create(
-         messages=[
-             {"role": "system", "content": prompt},
-             {"role": "user", "content": user_input}
-         ],
-         model="llama3-8b-8192",
-         temperature=0.5,
-         top_p=1,
-         stream=False,
-     )
-     return chat_completion.choices[0].message.content

  # Transcribe Audio
  def transcribe_audio(audio_path, model_size="base"):
-     model = whisper.load_model(model_size)
-     result = model.transcribe(audio_path)
-     return result["text"]

- # Generate Speech
- def generate_speech(text, output_file, speaker_wav, language="en", use_gpu=True):
      if not os.path.exists(speaker_wav):
          raise FileNotFoundError("Reference audio file not found. Please upload or record a valid audio.")

-     tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=use_gpu)
-     tts.tts_to_file(
-         text=text,
-         file_path=output_file,
-         speaker_wav=speaker_wav,
-         language=language,
-     )

- # Audio Frame Processing
  class AudioProcessor:
      def __init__(self):
          self.audio_frames = []

      def recv(self, frame):
-         self.audio_frames.append(frame.to_ndarray().tobytes())
          return frame

      def save_audio(self, file_path):
-         with open(file_path, "wb") as f:
-             for frame in self.audio_frames:
-                 f.write(frame)
          return file_path

  # Streamlit App
  def main():
      st.set_page_config(page_title="Vocal AI", layout="wide")
-     st.sidebar.title("Vocal-AI Settings")
-
      # User option for reference audio (Record or Upload)
      ref_audio_choice = st.sidebar.radio("Reference Audio", ("Upload", "Record"))

      ref_audio_path = None
      reference_audio_processor = None

-     if ref_audio_choice == "Upload":
-         reference_audio = st.sidebar.file_uploader("Upload Reference Audio", type=["wav", "mp3", "ogg"])
-         if reference_audio:
-             with NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_audio:
-                 temp_ref_audio.write(reference_audio.read())
-                 ref_audio_path = temp_ref_audio.name
-     else:
-         st.sidebar.write("Record your reference audio:")
-         reference_audio_processor = AudioProcessor()
-         webrtc_streamer(
-             key="ref_audio",
-             mode=WebRtcMode.SENDRECV,
-             client_settings=ClientSettings(rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}),
-             audio_receiver_size=1024,
-             video_processor_factory=None,
-             audio_processor_factory=lambda: reference_audio_processor,
-         )

-     st.title("Welcome to VocaL AI")
-     st.write("### How to Use")
-     st.write("1. Upload or record a reference audio file.")
-     st.write("2. Choose between text or audio input.")
-     st.write("3. If audio input is selected, record and submit your audio.")
-     st.write("4. Click 'Generate Speech' to hear the AI response in your cloned voice.")
-
-     # User Input (Text or Audio)
-     input_type = st.radio("Choose Input Type", ("Text", "Audio"))
-     user_input = None
-     user_audio_processor = None
-
-     if input_type == "Text":
-         user_input = st.text_area("Enter your text here")
-     else:
-         st.write("Record your voice:")
-         user_audio_processor = AudioProcessor()
-         webrtc_streamer(
-             key="user_audio",
-             mode=WebRtcMode.SENDRECV,
-             client_settings=ClientSettings(rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}),
-             audio_receiver_size=1024,
-             video_processor_factory=None,
-             audio_processor_factory=lambda: user_audio_processor,
-         )

-     if st.button("Generate Speech"):
-         # Handle Reference Audio
-         if reference_audio_processor:
-             with NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_audio:
-                 reference_audio_processor.save_audio(temp_ref_audio.name)
-                 ref_audio_path = temp_ref_audio.name

          if not ref_audio_path:
-             st.error("Please upload or record reference audio.")
              return

-         # Handle User Input
-         if input_type == "Audio":
-             if user_audio_processor:
-                 with NamedTemporaryFile(delete=False, suffix=".wav") as temp_user_audio:
-                     user_audio_processor.save_audio(temp_user_audio.name)
-                     user_input = transcribe_audio(temp_user_audio.name)
-                     os.unlink(temp_user_audio.name)
-
          if not user_input:
-             st.error("Please enter text or record audio.")
              return

-         # Get AI Response
-         response_text = get_llm_response(API_KEY, user_input)
-
-         # Generate Speech
-         output_audio_path = "output_speech.wav"
-         try:
-             generate_speech(response_text, output_audio_path, ref_audio_path)
-             os.unlink(ref_audio_path)
-             st.audio(output_audio_path, format="audio/wav")
-         except FileNotFoundError as e:
-             st.error(str(e))

  if __name__ == "__main__":
-     main()
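The functional heart of the commit is the speech path: the removed lines above called the high-level TTS.api wrapper per request, while the new version of the file below loads the XTTS checkpoint once at startup and reuses it. A minimal, illustrative contrast of the two call patterns; here "model" refers to the module-level Xtts instance created in the new code, and reference.wav is a hypothetical speaker clip:

# Old path (removed): one-shot wrapper call that loads, clones, and writes per request.
# tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
# tts.tts_to_file(text=text, file_path="out.wav", speaker_wav="reference.wav", language="en")

# New path (added): reuse the preloaded Xtts model and synthesize from conditioning latents.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path="reference.wav",  # hypothetical reference clip
    gpt_cond_len=30, gpt_cond_chunk_len=4, max_ref_length=60,
)
out = model.inference(
    "Hello from the cloned voice.", "en",
    gpt_cond_latent, speaker_embedding,
    repetition_penalty=5.0, temperature=0.75,
)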
 
  import os
+ import io
  import torch
  import whisper
  import streamlit as st
  from groq import Groq
  from dotenv import load_dotenv
  from tempfile import NamedTemporaryFile
  from streamlit_webrtc import webrtc_streamer, WebRtcMode, ClientSettings
  import av
  import numpy as np
+ import uuid
+ import time

+ # Load environment variables
  load_dotenv()
  API_KEY = os.getenv("GROQ_API_KEY")
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ # By using XTTS you agree to CPML license
+ os.environ["COQUI_TOS_AGREED"] = "1"
+
+ # For proper language detection
+ import langid
+
+ # Import TTS components
+ from TTS.api import TTS
+ from TTS.tts.configs.xtts_config import XttsConfig
+ from TTS.tts.models.xtts import Xtts
+ from TTS.utils.generic_utils import get_user_data_dir
+
+ # Download and configure XTTS model
+ print("Downloading Coqui XTTS V2 if not already downloaded")
+ from TTS.utils.manage import ModelManager
+
+ model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
+ ModelManager().download_model(model_name)
+ model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
+ print("XTTS downloaded")
+
+ config = XttsConfig()
+ config.load_json(os.path.join(model_path, "config.json"))
+
+ model = Xtts.init_from_config(config)
+ model.load_checkpoint(
+     config,
+     checkpoint_path=os.path.join(model_path, "model.pth"),
+     vocab_path=os.path.join(model_path, "vocab.json"),
+     eval=True,
+     use_deepspeed=True,
+ )
+ if torch.cuda.is_available():
+     model.cuda()
+
+ supported_languages = config.languages

  # LLM Response Function
  def get_llm_response(api_key, user_input):
+     if not api_key:
+         return "API key not found. Please set the GROQ_API_KEY environment variable."
+
      client = Groq(api_key=api_key)
      prompt = (
          "IMPORTANT: You are an AI assistant that MUST provide responses in 25 words or less.\n"

          "Your response will be converted to speech. Maximum 25 words."
      )

+     try:
+         chat_completion = client.chat.completions.create(
+             messages=[
+                 {"role": "system", "content": prompt},
+                 {"role": "user", "content": user_input}
+             ],
+             model="llama3-8b-8192",
+             temperature=0.5,
+             top_p=1,
+             stream=False,
+         )
+         return chat_completion.choices[0].message.content
+     except Exception as e:
+         return f"Error with LLM: {str(e)}"

  # Transcribe Audio
  def transcribe_audio(audio_path, model_size="base"):
+     try:
+         model = whisper.load_model(model_size)
+         result = model.transcribe(audio_path)
+         return result["text"]
+     except Exception as e:
+         return f"Error transcribing audio: {str(e)}"

+ # Generate Speech using the configured XTTS model
+ def generate_speech(text, output_file, speaker_wav, language="en"):
      if not os.path.exists(speaker_wav):
          raise FileNotFoundError("Reference audio file not found. Please upload or record a valid audio.")

+     if language not in supported_languages:
+         st.warning(f"Language {language} is not supported. Defaulting to English.")
+         language = "en"
+
+     # Detect language if text is long enough
+     detected_lang = langid.classify(text)[0]
+     if detected_lang == "zh":
+         detected_lang = "zh-cn"
+
+     # Use the configured model directly
+     try:
+         t_latent = time.time()
+         gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
+             audio_path=speaker_wav,
+             gpt_cond_len=30,
+             gpt_cond_chunk_len=4,
+             max_ref_length=60
+         )
+
+         out = model.inference(
+             text,
+             language,
+             gpt_cond_latent,
+             speaker_embedding,
+             repetition_penalty=5.0,
+             temperature=0.75,
+         )
+
+         # Save the audio to file
+         torch.tensor(out["wav"]).unsqueeze(0).cpu().numpy()
+         import soundfile as sf
+         sf.write(output_file, out["wav"], 24000, 'PCM_24')
+
+         return True, "Speech generated successfully"
+     except Exception as e:
+         return False, f"Error generating speech: {str(e)}"

+ # Audio Frame Processing for WebRTC
  class AudioProcessor:
      def __init__(self):
          self.audio_frames = []
+         self.sample_rate = 24000  # XTTS expects 24kHz

      def recv(self, frame):
+         sound = frame.to_ndarray()
+         self.audio_frames.append(sound)
          return frame

      def save_audio(self, file_path):
+         if not self.audio_frames:
+             return None
+
+         # Concatenate audio frames
+         concat_audio = np.concatenate(self.audio_frames, axis=0)
+
+         # Save as WAV file
+         import soundfile as sf
+         sf.write(file_path, concat_audio, self.sample_rate)
+
          return file_path

  # Streamlit App
  def main():
      st.set_page_config(page_title="Vocal AI", layout="wide")
+
+     st.title("VocaL AI - Voice Cloning Assistant")
+     st.write("Clone your voice and interact with an AI assistant that responds in your voice!")
+
+     st.sidebar.title("Settings")
+
+     # Language selection
+     language = st.sidebar.selectbox(
+         "Output Language",
+         supported_languages,
+         index=supported_languages.index("en") if "en" in supported_languages else 0
+     )
+
+     # TOS agreement
+     agree_tos = st.sidebar.checkbox("I agree to the Coqui Public Model License (CPML)", value=False)
+
      # User option for reference audio (Record or Upload)
      ref_audio_choice = st.sidebar.radio("Reference Audio", ("Upload", "Record"))

      ref_audio_path = None
      reference_audio_processor = None

+     col1, col2 = st.columns(2)
+
+     with col1:
+         st.header("Step 1: Provide Reference Voice")
+         if ref_audio_choice == "Upload":
+             reference_audio = st.file_uploader("Upload Reference Audio", type=["wav", "mp3", "ogg"])
+             if reference_audio:
+                 with NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_audio:
+                     temp_ref_audio.write(reference_audio.read())
+                     ref_audio_path = temp_ref_audio.name
+                 st.audio(ref_audio_path)
+         else:
+             st.write("Record your reference voice:")
+             reference_audio_processor = AudioProcessor()
+             webrtc_ctx = webrtc_streamer(
+                 key="ref_audio",
+                 mode=WebRtcMode.SENDRECV,
+                 client_settings=ClientSettings(
+                     rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
+                     media_stream_constraints={"audio": True, "video": False},
+                 ),
+                 audio_receiver_size=1024,
+                 video_processor_factory=None,
+                 audio_processor_factory=lambda: reference_audio_processor,
+             )
+
+             if webrtc_ctx.state.playing and reference_audio_processor is not None:
+                 st.info("Recording... Speak into your microphone.")
+
+             if st.button("Save Reference Audio"):
+                 if reference_audio_processor and reference_audio_processor.audio_frames:
+                     with NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_audio:
+                         reference_audio_processor.save_audio(temp_ref_audio.name)
+                         ref_audio_path = temp_ref_audio.name
+                     st.success("Reference audio saved!")
+                     st.audio(ref_audio_path)
+                 else:
+                     st.error("No audio recorded. Please speak into your microphone.")

+     with col2:
+         st.header("Step 2: Ask Something")
+         # User Input (Text or Audio)
+         input_type = st.radio("Choose Input Type", ("Text", "Audio"))
+         user_input = None
+         user_audio_processor = None

+         if input_type == "Text":
+             user_input = st.text_area("Enter your question or prompt here")
+         else:
+             st.write("Record your question:")
+             user_audio_processor = AudioProcessor()
+             webrtc_ctx_user = webrtc_streamer(
+                 key="user_audio",
+                 mode=WebRtcMode.SENDRECV,
+                 client_settings=ClientSettings(
+                     rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
+                     media_stream_constraints={"audio": True, "video": False},
+                 ),
+                 audio_receiver_size=1024,
+                 video_processor_factory=None,
+                 audio_processor_factory=lambda: user_audio_processor,
+             )
+
+             if webrtc_ctx_user.state.playing and user_audio_processor is not None:
+                 st.info("Recording... Ask your question")
+
+             if st.button("Process Recording"):
+                 if user_audio_processor and user_audio_processor.audio_frames:
+                     with NamedTemporaryFile(delete=False, suffix=".wav") as temp_user_audio:
+                         user_audio_processor.save_audio(temp_user_audio.name)
+                         user_input = transcribe_audio(temp_user_audio.name)
+                     st.write(f"Transcribed: {user_input}")
+                 else:
+                     st.error("No audio recorded. Please speak into your microphone.")

+     # Process and generate response
+     if st.button("Generate AI Response in My Voice"):
+         if not agree_tos:
+             st.error("Please agree to the Coqui Public Model License to continue.")
+             return
+
          if not ref_audio_path:
+             st.error("Please provide reference audio (upload or record).")
              return

          if not user_input:
+             st.error("Please enter text or record a question.")
              return

+         with st.spinner("Processing..."):
+             # Get AI Response
+             llm_response = get_llm_response(API_KEY, user_input)
+             st.subheader("AI Response:")
+             st.write(llm_response)
+
+             # Generate Speech
+             output_audio_path = f"output_speech_{uuid.uuid4()}.wav"
+             success, message = generate_speech(
+                 llm_response,
+                 output_audio_path,
+                 ref_audio_path,
+                 language
+             )
+
+             if success:
+                 st.subheader("Listen to the response in your voice:")
+                 st.audio(output_audio_path, format="audio/wav")
+             else:
+                 st.error(message)

  if __name__ == "__main__":
+     main()
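For reference, the updated module can also be exercised outside Streamlit, since the pipeline functions are plain Python. A hedged sketch, assuming the file is importable as app, that GROQ_API_KEY is set, and that question.wav and my_voice.wav are hypothetical local clips; note that importing the module triggers the XTTS download and checkpoint load:

# Illustrative usage of the functions touched by this commit; not part of the commit itself.
import app  # module-level code downloads and loads XTTS on import

question = app.transcribe_audio("question.wav")       # hypothetical recorded question (any string also works)
reply = app.get_llm_response(app.API_KEY, question)   # short reply from Groq's llama3-8b-8192
ok, msg = app.generate_speech(reply, "reply.wav", speaker_wav="my_voice.wav", language="en")
if ok:
    print("Wrote reply.wav in the cloned voice")
else:
    print(msg)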