adeel707 commited on
Commit
ddf4b47
·
verified ·
1 Parent(s): 266582d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import whisper
4
+ import streamlit as st
5
+ from groq import Groq
6
+ from TTS.api import TTS
7
+ from tempfile import NamedTemporaryFile
8
+ from streamlit_webrtc import webrtc_streamer, WebRtcMode, ClientSettings
9
+ import av
10
+
11
+ # LLM Response Function
12
+ def get_llm_response(api_key, user_input):
13
+ client = Groq(api_key=api_key)
14
+ prompt = (
15
+ "IMPORTANT: You are an AI assistant that MUST provide responses in 25 words or less.\n"
16
+ "CRITICAL RULES:\n"
17
+ "1. NEVER exceed 25 words unless absolutely necessary.\n"
18
+ "2. Always give a complete sentence with full context.\n"
19
+ "3. Answer directly and precisely.\n"
20
+ "4. Use clear, simple language.\n"
21
+ "5. Maintain a polite, professional tone.\n"
22
+ "6. NO lists, bullet points, or multiple paragraphs.\n"
23
+ "7. NEVER apologize for brevity - embrace it.\n"
24
+ "Your response will be converted to speech. Maximum 25 words."
25
+ )
26
+
27
+ chat_completion = client.chat.completions.create(
28
+ messages=[
29
+ {"role": "system", "content": prompt},
30
+ {"role": "user", "content": user_input}
31
+ ],
32
+ model="llama3-8b-8192",
33
+ temperature=0.5,
34
+ top_p=1,
35
+ stream=False,
36
+ )
37
+ return chat_completion.choices[0].message.content
38
+
39
+ # Transcribe Audio
40
+ def transcribe_audio(audio_path, model_size="base"):
41
+ model = whisper.load_model(model_size)
42
+ result = model.transcribe(audio_path)
43
+ return result["text"]
44
+
45
+ # Generate Speech
46
+ def generate_speech(text, output_file, speaker_wav, language="en", use_gpu=True):
47
+ if not os.path.exists(speaker_wav):
48
+ raise FileNotFoundError("Reference audio file not found. Please upload or record a valid audio.")
49
+
50
+ tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=use_gpu)
51
+ tts.tts_to_file(
52
+ text=text,
53
+ file_path=output_file,
54
+ speaker_wav=speaker_wav,
55
+ language=language,
56
+ )
57
+
58
+ # Audio Frame Processing
59
+ class AudioProcessor:
60
+ def __init__(self):
61
+ self.audio_frames = []
62
+
63
+ def recv(self, frame):
64
+ self.audio_frames.append(frame.to_ndarray().tobytes())
65
+ return frame
66
+
67
+ def save_audio(self, file_path):
68
+ with open(file_path, "wb") as f:
69
+ for frame in self.audio_frames:
70
+ f.write(frame)
71
+ return file_path
72
+
73
+ # Streamlit App
74
+ def main():
75
+ st.set_page_config(page_title="Vocal AI", layout="wide")
76
+ st.sidebar.title("Vocal-AI Settings")
77
+
78
+ # User option for reference audio (Record or Upload)
79
+ ref_audio_choice = st.sidebar.radio("Reference Audio", ("Upload", "Record"))
80
+
81
+ ref_audio_path = None
82
+ reference_audio_processor = None
83
+
84
+ if ref_audio_choice == "Upload":
85
+ reference_audio = st.sidebar.file_uploader("Upload Reference Audio", type=["wav", "mp3", "ogg"])
86
+ if reference_audio:
87
+ with NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_audio:
88
+ temp_ref_audio.write(reference_audio.read())
89
+ ref_audio_path = temp_ref_audio.name
90
+ else:
91
+ st.sidebar.write("Record your reference audio:")
92
+ reference_audio_processor = AudioProcessor()
93
+ webrtc_streamer(
94
+ key="ref_audio",
95
+ mode=WebRtcMode.SENDRECV,
96
+ client_settings=ClientSettings(rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}),
97
+ audio_receiver_size=1024,
98
+ video_processor_factory=None,
99
+ audio_processor_factory=lambda: reference_audio_processor,
100
+ )
101
+
102
+ st.title("Welcome to VocaL AI")
103
+ st.write("### How to Use")
104
+ st.write("1. Upload or record a reference audio file.")
105
+ st.write("2. Choose between text or audio input.")
106
+ st.write("3. If audio input is selected, record and submit your audio.")
107
+ st.write("4. Click 'Generate Speech' to hear the AI response in your cloned voice.")
108
+
109
+ # User Input (Text or Audio)
110
+ input_type = st.radio("Choose Input Type", ("Text", "Audio"))
111
+ user_input = None
112
+ user_audio_processor = None
113
+
114
+ if input_type == "Text":
115
+ user_input = st.text_area("Enter your text here")
116
+ else:
117
+ st.write("Record your voice:")
118
+ user_audio_processor = AudioProcessor()
119
+ webrtc_streamer(
120
+ key="user_audio",
121
+ mode=WebRtcMode.SENDRECV,
122
+ client_settings=ClientSettings(rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}),
123
+ audio_receiver_size=1024,
124
+ video_processor_factory=None,
125
+ audio_processor_factory=lambda: user_audio_processor,
126
+ )
127
+
128
+ if st.button("Generate Speech"):
129
+ # Handle Reference Audio
130
+ if reference_audio_processor:
131
+ with NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_audio:
132
+ reference_audio_processor.save_audio(temp_ref_audio.name)
133
+ ref_audio_path = temp_ref_audio.name
134
+
135
+ if not ref_audio_path:
136
+ st.error("Please upload or record reference audio.")
137
+ return
138
+
139
+ # Handle User Input
140
+ if input_type == "Audio":
141
+ if user_audio_processor:
142
+ with NamedTemporaryFile(delete=False, suffix=".wav") as temp_user_audio:
143
+ user_audio_processor.save_audio(temp_user_audio.name)
144
+ user_input = transcribe_audio(temp_user_audio.name)
145
+ os.unlink(temp_user_audio.name)
146
+
147
+ if not user_input:
148
+ st.error("Please enter text or record audio.")
149
+ return
150
+
151
+ # Get AI Response
152
+ api_key = st.secrets["GROQ_API_KEY"]
153
+ response_text = get_llm_response(api_key, user_input)
154
+
155
+ # Generate Speech
156
+ output_audio_path = "output_speech.wav"
157
+ try:
158
+ generate_speech(response_text, output_audio_path, ref_audio_path)
159
+ os.unlink(ref_audio_path)
160
+ st.audio(output_audio_path, format="audio/wav")
161
+ except FileNotFoundError as e:
162
+ st.error(str(e))
163
+
164
+ if __name__ == "__main__":
165
+ main()