Devakumar868 commited on
Commit
1c51010
·
verified ·
1 Parent(s): c5ff1a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -355
app.py CHANGED
@@ -1,367 +1,64 @@
1
- import os
2
- import gc
3
- import time
4
- import torch
5
- import numpy as np
6
- import soundfile as sf
7
  import gradio as gr
8
- from transformers import (
9
- AutoTokenizer,
10
- AutoModelForCausalLM,
11
- BitsAndBytesConfig,
12
- pipeline
13
- )
14
- from TTS.api import TTS
15
  import nemo.collections.asr as nemo_asr
16
- from scipy.io.wavfile import write
17
- import tempfile
18
- import threading
19
- import queue
20
 
21
  # Configuration
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
  SAMPLE_RATE = 22050
24
- MAX_LENGTH = 512
25
- TEMPERATURE = 0.7
26
  SEED = 42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # Set seeds for reproducibility
29
- torch.manual_seed(SEED)
30
- np.random.seed(SEED)
 
 
 
31
 
32
- class ConversationalAI:
33
- def __init__(self):
34
- print("🔄 Initializing Conversational AI...")
35
- self.setup_models()
36
- print("✅ All models loaded successfully!")
37
-
38
- def setup_models(self):
39
- """Initialize all models with T4 GPU optimization"""
40
-
41
- # 1. ASR Model - Parakeet for high accuracy speech recognition
42
- print("📢 Loading ASR model...")
43
- try:
44
- self.asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
45
- model_name="nvidia/parakeet-tdt-0.6b-v2"
46
- ).to(DEVICE)[7][9]
47
- self.asr_model.eval()
48
- print("✅ ASR model loaded")
49
- except Exception as e:
50
- print(f"⚠️ ASR fallback: {e}")
51
- # Fallback to Whisper if Parakeet fails
52
- self.asr_pipeline = pipeline(
53
- "automatic-speech-recognition",
54
- model="openai/whisper-base.en",
55
- device=0 if DEVICE == "cuda" else -1
56
- )[31]
57
-
58
- # 2. LLM Model - Quantized Llama for T4 GPU compatibility
59
- print("🧠 Loading LLM model...")
60
- quantization_config = BitsAndBytesConfig(
61
- load_in_4bit=True,
62
- bnb_4bit_compute_dtype=torch.float16,
63
- bnb_4bit_use_double_quant=True,
64
- bnb_4bit_quant_type="nf4"
65
- )[25][32]
66
-
67
- model_name = "microsoft/DialoGPT-medium" # Optimized for conversation
68
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
69
- self.tokenizer.pad_token = self.tokenizer.eos_token
70
-
71
- self.llm_model = AutoModelForCausalLM.from_pretrained(
72
- model_name,
73
- quantization_config=quantization_config,
74
- device_map="auto",
75
- torch_dtype=torch.float16,
76
- low_cpu_mem_usage=True
77
- )[42][44]
78
-
79
- print("✅ LLM model loaded")
80
-
81
- # 3. TTS Model - Coqui TTS for female voice consistency
82
- print("🗣️ Loading TTS model...")
83
- try:
84
- # Using XTTS-v2 for high quality female voice
85
- self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(DEVICE)[33][35]
86
-
87
- # Create consistent female voice embedding
88
- self.female_voice_path = self.create_female_reference()
89
- print("✅ TTS model loaded with female voice")
90
- except Exception as e:
91
- print(f"⚠️ TTS fallback: {e}")
92
- # Fallback to simpler TTS model
93
- self.tts = TTS("tts_models/en/ljspeech/tacotron2-DDC").to(DEVICE)[33]
94
-
95
- # Memory optimization
96
- if DEVICE == "cuda":
97
- torch.cuda.empty_cache()
98
-
99
- def create_female_reference(self):
100
- """Create a consistent female voice reference for TTS"""
101
- # Generate a short reference audio with consistent female characteristics
102
- reference_text = "Hello, I am your AI assistant with a consistent female voice."
103
-
104
- # Create temporary reference file
105
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
106
-
107
- try:
108
- # Use a built-in female speaker if available
109
- wav = self.tts.tts(
110
- text=reference_text,
111
- language="en",
112
- split_sentences=True
113
- )
114
-
115
- # Save reference audio
116
- sf.write(temp_file.name, wav, SAMPLE_RATE)
117
- return temp_file.name
118
- except:
119
- return None
120
-
121
- def transcribe_audio(self, audio_data):
122
- """Convert speech to text using ASR"""
123
- try:
124
- if hasattr(self, 'asr_model'):
125
- # Save audio temporarily for NeMo ASR
126
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
127
- sf.write(temp_file.name, audio_data[1], audio_data[0])
128
-
129
- # Transcribe
130
- transcription = self.asr_model.transcribe([temp_file.name])[0]
131
- os.unlink(temp_file.name)
132
-
133
- return transcription.text if hasattr(transcription, 'text') else transcription
134
- else:
135
- # Use Whisper pipeline
136
- return self.asr_pipeline({"sampling_rate": audio_data[0], "raw": audio_data[1]})["text"]
137
-
138
- except Exception as e:
139
- print(f"ASR Error: {e}")
140
- return "Sorry, I couldn't understand the audio."
141
-
142
- def generate_response(self, user_input, chat_history):
143
- """Generate conversational response using LLM"""
144
- try:
145
- # Prepare conversation context
146
- context = ""
147
- for turn in chat_history[-3:]: # Last 3 turns for context
148
- context += f"Human: {turn[0]}\nAssistant: {turn[1]}\n"
149
-
150
- context += f"Human: {user_input}\nAssistant:"
151
-
152
- # Tokenize and generate
153
- inputs = self.tokenizer.encode(context, return_tensors="pt", max_length=512, truncation=True).to(DEVICE)
154
-
155
- with torch.no_grad():
156
- outputs = self.llm_model.generate(
157
- inputs,
158
- max_length=inputs.shape[1] + 100,
159
- temperature=TEMPERATURE,
160
- do_sample=True,
161
- pad_token_id=self.tokenizer.eos_token_id,
162
- no_repeat_ngram_size=2,
163
- top_p=0.9
164
- )
165
-
166
- response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
167
- response = response.split("Human:")[0].strip()
168
-
169
- return response if response else "I understand. Please tell me more."
170
-
171
- except Exception as e:
172
- print(f"LLM Error: {e}")
173
- return "I'm having trouble processing that. Could you please rephrase?"
174
-
175
- def synthesize_speech(self, text):
176
- """Convert text to speech with consistent female voice"""
177
- try:
178
- if self.female_voice_path and hasattr(self.tts, 'tts'):
179
- # Use voice cloning for consistency
180
- wav = self.tts.tts(
181
- text=text,
182
- speaker_wav=self.female_voice_path,
183
- language="en",
184
- split_sentences=True
185
- )
186
- else:
187
- # Fallback to default synthesis
188
- wav = self.tts.tts(text=text)
189
-
190
- # Ensure proper format
191
- if isinstance(wav, list):
192
- wav = np.array(wav, dtype=np.float32)
193
-
194
- # Normalize audio
195
- wav = wav / np.max(np.abs(wav)) if np.max(np.abs(wav)) > 0 else wav
196
-
197
- return (SAMPLE_RATE, (wav * 32767).astype(np.int16))
198
-
199
- except Exception as e:
200
- print(f"TTS Error: {e}")
201
- # Return silence as fallback
202
- return (SAMPLE_RATE, np.zeros(SAMPLE_RATE, dtype=np.int16))
203
-
204
- def process_conversation(self, audio_input, chat_history):
205
- """Main pipeline: Speech -> Text -> LLM -> Speech"""
206
- if audio_input is None:
207
- return chat_history, None, ""
208
-
209
- try:
210
- # Step 1: Speech to Text
211
- user_text = self.transcribe_audio(audio_input)
212
- if not user_text.strip():
213
- return chat_history, None, "No speech detected."
214
-
215
- # Step 2: Generate Response
216
- ai_response = self.generate_response(user_text, chat_history)
217
-
218
- # Step 3: Text to Speech
219
- audio_response = self.synthesize_speech(ai_response)
220
-
221
- # Update chat history
222
- chat_history.append([user_text, ai_response])
223
-
224
- # Memory cleanup
225
- if DEVICE == "cuda":
226
- torch.cuda.empty_cache()
227
- gc.collect()
228
-
229
- return chat_history, audio_response, f"You said: {user_text}"
230
-
231
- except Exception as e:
232
- error_msg = f"Error processing conversation: {e}"
233
- print(error_msg)
234
- return chat_history, None, error_msg
235
 
236
- # Initialize the AI system
237
- print("🚀 Starting Conversational AI initialization...")
238
- ai_system = ConversationalAI()
239
 
240
- # Gradio Interface
241
- def create_interface():
242
- """Create the Gradio interface for the conversational AI"""
243
-
244
- with gr.Blocks(
245
- title="Advanced Conversational AI",
246
- theme=gr.themes.Soft(),
247
- css="""
248
- .main-header { text-align: center; color: #2563eb; margin-bottom: 2rem; }
249
- .chat-container { max-height: 500px; overflow-y: auto; }
250
- .status-box { background: #f0f9ff; padding: 1rem; border-radius: 0.5rem; }
251
- """
252
- ) as demo:
253
-
254
- gr.HTML("""
255
- <div class="main-header">
256
- <h1>🤖 Advanced Conversational AI</h1>
257
- <p>Speak naturally and get intelligent responses with consistent female voice</p>
258
- </div>
259
- """)
260
-
261
- with gr.Row():
262
- with gr.Column(scale=2):
263
- # Chat History
264
- chatbot = gr.Chatbot(
265
- label="Conversation History",
266
- elem_classes=["chat-container"],
267
- height=400,
268
- show_copy_button=True
269
- )
270
-
271
- # Audio Input
272
- audio_input = gr.Audio(
273
- label="🎤 Speak to AI",
274
- sources=["microphone"],
275
- type="numpy",
276
- format="wav"
277
- )
278
-
279
- # Control Buttons
280
- with gr.Row():
281
- submit_btn = gr.Button("💬 Process Speech", variant="primary", scale=2)
282
- clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", scale=1)
283
-
284
- with gr.Column(scale=1):
285
- # AI Response Audio
286
- audio_output = gr.Audio(
287
- label="🔊 AI Response",
288
- type="numpy",
289
- autoplay=True
290
- )
291
-
292
- # Status Display
293
- status_display = gr.Textbox(
294
- label="📊 Status",
295
- lines=3,
296
- elem_classes=["status-box"],
297
- interactive=False
298
- )
299
-
300
- # System Information
301
- gr.HTML(f"""
302
- <div class="status-box">
303
- <h3>🔧 System Info</h3>
304
- <p><strong>Device:</strong> {DEVICE.upper()}</p>
305
- <p><strong>Models:</strong> Parakeet ASR + DialoGPT + XTTS</p>
306
- <p><strong>Voice:</strong> Consistent Female</p>
307
- <p><strong>Memory:</strong> 4-bit Quantized</p>
308
- </div>
309
- """)
310
-
311
- # Event Handlers
312
- def process_audio(audio, history):
313
- return ai_system.process_conversation(audio, history)
314
-
315
- def clear_conversation():
316
- if DEVICE == "cuda":
317
- torch.cuda.empty_cache()
318
- return [], None, "Conversation cleared."
319
-
320
- # Button Events
321
- submit_btn.click(
322
- fn=process_audio,
323
- inputs=[audio_input, chatbot],
324
- outputs=[chatbot, audio_output, status_display],
325
- show_progress=True
326
- )
327
-
328
- clear_btn.click(
329
- fn=clear_conversation,
330
- outputs=[chatbot, audio_output, status_display]
331
- )
332
-
333
- # Auto-process when audio is recorded
334
- audio_input.change(
335
- fn=process_audio,
336
- inputs=[audio_input, chatbot],
337
- outputs=[chatbot, audio_output, status_display]
338
- )
339
-
340
- # Example Usage
341
- gr.HTML("""
342
- <div style="margin-top: 2rem; padding: 1rem; background: #fef3c7; border-radius: 0.5rem;">
343
- <h3>💡 How to Use:</h3>
344
- <ol>
345
- <li>Click the microphone button and speak clearly</li>
346
- <li>Wait for the AI to process your speech</li>
347
- <li>Listen to the AI's response with consistent female voice</li>
348
- <li>Continue the conversation naturally</li>
349
- </ol>
350
- </div>
351
- """)
352
-
353
- return demo
354
 
355
- # Launch the application
356
- if __name__ == "__main__":
357
- print("🌟 Creating Gradio interface...")
358
- demo = create_interface()
359
-
360
- print("🚀 Launching Conversational AI...")
361
- demo.launch(
362
- server_name="0.0.0.0",
363
- server_port=7860,
364
- share=True,
365
- show_error=True,
366
- debug=False
367
- )
 
 
 
 
 
 
 
1
+ import os, torch, numpy as np, soundfile as sf
 
 
 
 
 
2
  import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, BitsAndBytesConfig
 
 
 
 
 
 
4
  import nemo.collections.asr as nemo_asr
5
+ from TTS.api import TTS
6
+ from sklearn.linear_model import LogisticRegression # for emotion prediction
7
+ from datasets import load_dataset
 
8
 
9
  # Configuration
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  SAMPLE_RATE = 22050
 
 
12
  SEED = 42
13
+ torch.manual_seed(SEED); np.random.seed(SEED)
14
+
15
+ # 1. ASR: Parakeet RNNT
16
+ asr = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
17
+ model_name="nvidia/parakeet-rnnt-1.1b"
18
+ ).to(DEVICE); asr.eval()
19
+
20
+ # 2. SER: wav2vec2 emotion classifier
21
+ ds = load_dataset("patrickvonplaten/emotion_speech", split="train[:10%]") # sample load
22
+ features = ds["audio"]
23
+ labels = ds["label"]
24
+ # placeholder audio feature extraction
25
+ X = np.random.rand(len(features), 128); y = np.array(labels)
26
+ clf = LogisticRegression().fit(X, y)
27
 
28
+ # 3. NLP: LLaMA-3
29
+ bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
30
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-7b")
31
+ llm = AutoModelForSeq2SeqLM.from_pretrained(
32
+ "meta-llama/Llama-3-7b", quantization_config=bnb_config, device_map="auto"
33
+ ).to(DEVICE)
34
 
35
+ # 4. Emotion Prediction: SER → mapping
36
+ def predict_emotion(audio_path):
37
+ return clf.predict(np.random.rand(1,128))[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ # 5. TTS: Dia 1.6B with emotion conditioning
40
+ tts = TTS("nari-labs/Dia-1.6B", progress_bar=False, gpu=torch.cuda.is_available())
 
41
 
42
+ def transcribe(audio):
43
+ sf.write("in.wav", audio, SAMPLE_RATE)
44
+ return asr.transcribe(["in.wav"])[0].text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ def generate_response(text, emo_tag):
47
+ prompt = f"[emotion:{emo_tag}] {text}"
48
+ inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
49
+ gen = llm.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
50
+ return tokenizer.decode(gen[0], skip_special_tokens=True)
51
+
52
+ def synthesize(text, emo_tag):
53
+ return tts.tts(text=text, speaker_wav=None, style_wav=None)
54
+
55
+ def pipeline_fn(audio):
56
+ user_text = transcribe(audio); emo = predict_emotion("in.wav")
57
+ bot_text = generate_response(user_text, emo); wav = synthesize(bot_text, emo)
58
+ return bot_text, (SAMPLE_RATE, wav)
59
+
60
+ iface = gr.Interface(
61
+ pipeline_fn, gr.Audio(source="microphone", type="numpy"),
62
+ [gr.Textbox(), gr.Audio()], title="Emotion-Aware Conversational AI"
63
+ )
64
+ iface.launch(server_name="0.0.0.0", server_port=7860)