Devakumar868 committed on
Commit
dbc05eb
·
verified ·
1 Parent(s): 8dbae03

Update app.py

Files changed (1)
  1. app.py +142 -68
app.py CHANGED
@@ -10,6 +10,7 @@ import time
from datetime import datetime
import os
import sys
+ import gc

# Import with enhanced error handling
try:
@@ -56,6 +57,13 @@ class ConversationManager:
        self.history = []
        self.current_emotion = "neutral"

+ def optimize_gpu_memory():
+     """Optimize GPU memory usage"""
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+     gc.collect()
+
def check_system_info():
    """Check system capabilities"""
    print("🔍 System Information:")
@@ -66,14 +74,20 @@ def check_system_info():
        print(f"✅ CUDA: {torch.cuda.get_device_name()}")
        print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        print(f"🔥 CUDA Version: {torch.version.cuda}")
+
+         # Check current memory usage
+         allocated = torch.cuda.memory_allocated() / 1e9
+         cached = torch.cuda.memory_reserved() / 1e9
+         print(f"📊 Current GPU Usage: {allocated:.1f}GB allocated, {cached:.1f}GB cached")
    else:
        print("⚠️ CUDA not available, using CPU")

def load_models():
-     """Load all models with enhanced error handling"""
+     """Load all models with enhanced memory management"""
    global asr_pipe, qwen_model, qwen_tokenizer, tts_model, tts_type

    print("🚀 Loading Maya AI models...")
+     optimize_gpu_memory()

    # Load ASR model (Whisper)
    print("🎤 Loading Whisper for ASR...")
@@ -85,11 +99,12 @@ def load_models():
            device=0 if torch.cuda.is_available() else -1
        )
        print("✅ Whisper ASR loaded successfully!")
+         optimize_gpu_memory()
    except Exception as e:
        print(f"❌ Error loading Whisper: {e}")
        return False

-     # Load Qwen model
+     # Load Qwen model with memory optimization
    print("🧠 Loading Qwen2.5-1.5B for conversation...")
    try:
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"
@@ -102,23 +117,36 @@ def load_models():
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True,
-             low_cpu_mem_usage=True
+             low_cpu_mem_usage=True,
+             max_memory={0: "6GB"} if torch.cuda.is_available() else None  # Limit Qwen memory
        )
        print("✅ Qwen loaded successfully!")
+         optimize_gpu_memory()
    except Exception as e:
        print(f"❌ Error loading Qwen: {e}")
        return False

-     # Load Dia TTS
+     # Load Dia TTS with optimized settings
    if DIA_AVAILABLE:
        try:
-             print("Attempting to load Dia TTS...")
+             print("Attempting to load Dia TTS with optimized settings...")
+
+             # Clear memory before loading Dia
+             optimize_gpu_memory()
+
            tts_model = Dia.from_pretrained(
                "nari-labs/Dia-1.6B",
-                 compute_dtype="float16" if torch.cuda.is_available() else "float32"
+                 compute_dtype="float16" if torch.cuda.is_available() else "float32",
+                 low_cpu_mem_usage=True
            )
+
+             # Move to GPU if available
+             if torch.cuda.is_available():
+                 tts_model = tts_model.cuda()
+
            tts_type = "dia"
            print("✅ Dia TTS loaded successfully!")
+             optimize_gpu_memory()
            return True
        except Exception as e:
            print(f"⚠️ Dia TTS failed to load: {e}")
@@ -199,7 +227,7 @@ def speech_to_text_with_emotion(audio_input):
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)

        print("🔄 Running Whisper ASR...")
-         result = asr_pipe(audio_data)
+         result = asr_pipe(audio_data, language='en')  # Force English to avoid language detection

        transcription = result['text'].strip()
        print(f"Transcription: '{transcription}'")
@@ -217,8 +245,11 @@ def speech_to_text_with_emotion(audio_input):
        return "Sorry, I couldn't understand that. Please try again.", "neutral"

def generate_contextual_response(user_input, emotion, conversation_manager):
-     """Enhanced response generation"""
+     """Enhanced response generation with memory optimization"""
    try:
+         # Clear GPU cache before generation
+         optimize_gpu_memory()
+
        context = conversation_manager.get_context()

        emotional_prompts = {
@@ -237,7 +268,7 @@ Previous context: {context}
User emotion: {emotion}

Guidelines:
- - Keep responses concise (1-2 sentences)
+ - Keep responses very concise (1 sentence maximum)
- Be natural and conversational
- Show empathy and understanding
- Provide helpful responses
@@ -259,12 +290,13 @@ Guidelines:
        with torch.no_grad():
            generated_ids = qwen_model.generate(
                model_inputs.input_ids,
-                 max_new_tokens=100,
+                 max_new_tokens=50,  # Reduced for shorter responses
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.1,
-                 pad_token_id=qwen_tokenizer.eos_token_id
+                 pad_token_id=qwen_tokenizer.eos_token_id,
+                 attention_mask=model_inputs.attention_mask  # Fix attention mask warning
            )

        generated_ids = [
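The two added keyword arguments address a common transformers warning: when a tokenizer defines no pad token, generate() cannot tell padding from content, so the commit reuses the EOS id and passes the tokenizer's attention mask explicitly. A self-contained sketch of the pattern:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
inputs = tok("Hello, Maya!", return_tensors="pt")
with torch.no_grad():
    out = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # marks real tokens vs padding
        pad_token_id=tok.eos_token_id,         # reuse EOS where no pad token exists
        max_new_tokens=50,
    )
print(tok.decode(out[0], skip_special_tokens=True))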
@@ -277,6 +309,9 @@ Guidelines:
        if response.startswith("Maya:"):
            response = response[5:].strip()

+         # Clear cache after generation
+         optimize_gpu_memory()
+
        return response

    except Exception as e:
@@ -284,64 +319,88 @@ Guidelines:
        return "I'm sorry, I'm having trouble processing that right now."

def text_to_speech_emotional(text, emotion="neutral"):
-     """FIXED TTS with proper audio format for Gradio"""
+     """FIXED TTS with enhanced Dia configuration and memory management"""
    try:
        if tts_model is None:
            print(f"🔊 Maya says ({emotion}): {text}")
            return None

-         # Clear GPU cache
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
+         # Aggressive memory cleanup before TTS
+         optimize_gpu_memory()

        if tts_type == "dia":
+             # Simplified emotional markers for better audio quality
            emotional_markers = {
-                 "happy": "(excited) ",
-                 "sad": "(sad) ",
-                 "angry": "(calm) ",
-                 "surprised": "(surprised) ",
+                 "happy": "",  # Remove complex markers that might cause artifacts
+                 "sad": "",
+                 "angry": "",
+                 "surprised": "",
                "neutral": ""
            }

-             # Enhanced text for Dia
-             enhanced_text = f"[S1] {emotional_markers.get(emotion, '')}{text}"
-
-             # Add pauses for natural speech
-             if len(text) > 50:
-                 enhanced_text = enhanced_text.replace(". ", ". (pause) ")
-                 enhanced_text = enhanced_text.replace("! ", "! (pause) ")
-                 enhanced_text = enhanced_text.replace("? ", "? (pause) ")
-
-             print(f"Generating Dia TTS for: {enhanced_text}")
-
-             with torch.no_grad():
-                 audio_output = tts_model.generate(
-                     enhanced_text,
-                     use_torch_compile=False,
-                     verbose=False
-                 )
-
-             # FIXED: Proper audio processing for Gradio
-             if isinstance(audio_output, torch.Tensor):
-                 audio_output = audio_output.cpu().numpy()
-
-             # Ensure audio is in the right format
-             if len(audio_output.shape) > 1:
-                 audio_output = audio_output.squeeze()
-
-             # Normalize audio properly
-             if len(audio_output) > 0:
-                 max_val = np.max(np.abs(audio_output))
-                 if max_val > 0:
-                     audio_output = audio_output / max_val * 0.95
-
-             # CRITICAL FIX: Ensure audio is float32 and in correct range
-             audio_output = audio_output.astype(np.float32)
-
-             print(f"✅ Generated audio: shape={audio_output.shape}, dtype={audio_output.dtype}, range=[{audio_output.min():.3f}, {audio_output.max():.3f}]")
-
-             # Return in format Gradio expects: (sample_rate, audio_array)
-             return (44100, audio_output)
+             # Simplified text processing for Dia - NO COMPLEX MARKERS
+             # Keep it simple to avoid audio artifacts
+             enhanced_text = f"[S1] {text}"
+
+             # Remove pauses that might cause artifacts
+             # enhanced_text = enhanced_text.replace("(pause)", "")
+
+             # Limit text length to prevent memory issues
+             if len(enhanced_text) > 200:
+                 enhanced_text = enhanced_text[:200] + "..."
+
+             print(f"Generating Dia TTS for: {enhanced_text}")
+
+             try:
+                 with torch.no_grad():
+                     # Use more conservative settings for T4
+                     audio_output = tts_model.generate(
+                         enhanced_text,
+                         use_torch_compile=False,
+                         verbose=False,
+                         # Add these parameters for better quality
+                         temperature=0.7,
+                         top_p=0.9
+                     )
+
+                 # Enhanced audio processing
+                 if isinstance(audio_output, torch.Tensor):
+                     audio_output = audio_output.cpu().numpy()
+
+                 # Ensure proper audio format
+                 if len(audio_output.shape) > 1:
+                     audio_output = audio_output.squeeze()
+
+                 # More conservative normalization
+                 if len(audio_output) > 0:
+                     # Remove DC offset
+                     audio_output = audio_output - np.mean(audio_output)
+
+                     # Gentle normalization to prevent clipping
+                     max_val = np.max(np.abs(audio_output))
+                     if max_val > 0:
+                         audio_output = audio_output / max_val * 0.8  # More conservative scaling
+
+                 # Ensure correct data type
+                 audio_output = audio_output.astype(np.float32)
+
+                 # Validate audio output
+                 if np.any(np.isnan(audio_output)) or np.any(np.isinf(audio_output)):
+                     print("❌ Audio contains NaN or Inf values, regenerating...")
+                     return None
+
+                 print(f"✅ Generated audio: shape={audio_output.shape}, dtype={audio_output.dtype}, range=[{audio_output.min():.3f}, {audio_output.max():.3f}]")
+
+                 # Clear memory after generation
+                 optimize_gpu_memory()
+
+                 # Return audio with correct sample rate for Dia
+                 return (44100, audio_output)
+
+             except Exception as e:
+                 print(f"❌ Error in Dia generation: {e}")
+                 optimize_gpu_memory()
+                 return None

        else:
            print(f"🔊 Maya says ({emotion}): {text}")
@@ -349,6 +408,7 @@ def text_to_speech_emotional(text, emotion="neutral"):

    except Exception as e:
        print(f"❌ Error in TTS: {e}")
+         optimize_gpu_memory()
        print(f"🔊 Maya says ({emotion}): {text}")
        return None
@@ -358,19 +418,22 @@ conv_manager = ConversationManager()
def start_call():
    """Initialize call and return greeting"""
    conv_manager.clear()
-     greeting_text = "Hello! I'm Maya, your AI assistant. How can I help you today?"
+     optimize_gpu_memory()
+
+     greeting_text = "Hello! I'm Maya. How can I help you today?"  # Shorter greeting
    greeting_audio = text_to_speech_emotional(greeting_text, "happy")

    tts_status = f"Using {tts_type.upper()} TTS" if tts_type != "none" else "Text-only mode"
    return greeting_audio, greeting_text, f"📞 Call started! Maya is ready. {tts_status}"

def process_conversation(audio_input):
-     """Main conversation processing pipeline"""
+     """Main conversation processing pipeline with memory management"""
    if audio_input is None:
        return None, "Please record some audio first.", "", "❌ No audio input received."

    try:
        print("🔄 Processing conversation...")
+         optimize_gpu_memory()

        # STT + Emotion Detection
        user_text, emotion = speech_to_text_with_emotion(audio_input)
@@ -392,13 +455,19 @@ def process_conversation(audio_input):
        # Update history
        conv_manager.add_exchange(user_text, ai_response, emotion)

-         status = f"✅ Success! | Emotion: {emotion} | Exchange: {len(conv_manager.history)}/5 | TTS: {tts_type.upper()}"
+         # Memory status
+         if torch.cuda.is_available():
+             allocated = torch.cuda.memory_allocated() / 1e9
+             status = f"✅ Success! | Emotion: {emotion} | Exchange: {len(conv_manager.history)}/5 | GPU: {allocated:.1f}GB"
+         else:
+             status = f"✅ Success! | Emotion: {emotion} | Exchange: {len(conv_manager.history)}/5"

        return response_audio, ai_response, user_text, status

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
+         optimize_gpu_memory()
        return None, "I'm sorry, I encountered an error. Please try again.", "", error_msg

def get_conversation_history():
@@ -416,15 +485,16 @@ def get_conversation_history():
    return history_text

def end_call():
-     """End call"""
+     """End call with memory cleanup"""
    farewell_text = "Thank you for talking with me! Have a wonderful day!"
    farewell_audio = text_to_speech_emotional(farewell_text, "happy")
    conv_manager.clear()
+     optimize_gpu_memory()

    return farewell_audio, farewell_text, "📞❌ Call ended. Thank you!"

def create_interface():
-     """Create Gradio interface with FIXED audio components"""
+     """Create Gradio interface with enhanced audio settings"""
    with gr.Blocks(
        title="Maya AI - Speech-to-Speech Assistant",
        theme=gr.themes.Soft()
@@ -462,14 +532,18 @@ def create_interface():

            with gr.Column(scale=2):
                gr.HTML("<h3>🔊 Maya's Response</h3>")
-                 # FIXED: Audio component with proper settings
+                 # Enhanced audio component with better settings
                response_audio = gr.Audio(
                    label="Maya's Voice Response",
                    type="numpy",
                    interactive=False,
-                     autoplay=True,  # Enable autoplay
+                     autoplay=True,
                    show_download_button=True,
-                     show_share_button=False
+                     show_share_button=False,
+                     waveform_options=gr.WaveformOptions(
+                         waveform_color="#01C6FF",
+                         waveform_progress_color="#0066CC"
+                     )
                )

        with gr.Row():
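With type="numpy", gr.Audio expects a (sample_rate, samples) tuple, and float arrays should sit in [-1, 1]; that is why the TTS path returns (44100, audio_output) as float32. A tiny self-contained check (the 440 Hz tone is an arbitrary test signal, not from the commit):

import numpy as np
import gradio as gr

def tone():
    sr = 44100
    t = np.linspace(0, 1.0, sr, endpoint=False)
    wave = 0.8 * np.sin(2 * np.pi * 440.0 * t)
    return (sr, wave.astype(np.float32))  # the tuple shape gr.Audio(type="numpy") expects

demo = gr.Interface(fn=tone, inputs=None, outputs=gr.Audio(type="numpy"))
# demo.launch()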
@@ -515,7 +589,7 @@ def create_interface():
            outputs=[history_display]
        )

-         # Instructions
+         # Enhanced instructions
        gr.HTML("""
        <div style="margin-top: 30px; padding: 25px; background: #f8f9fa; border-radius: 15px;">
            <h3>💡 How to Use Maya AI:</h3>
@@ -529,12 +603,12 @@ def create_interface():
        </ol>

        <div style="margin-top: 20px; padding: 15px; background: #d1ecf1; border-radius: 8px;">
-             <p><strong>💡 Pro Tips:</strong></p>
+             <p><strong>🔧 Troubleshooting Audio Issues:</strong></p>
            <ul>
-                 <li>Speak clearly and close to your microphone</li>
-                 <li>Record for at least 2-3 seconds</li>
-                 <li>Use a quiet environment for best results</li>
-                 <li>Maya detects emotions and responds accordingly!</li>
+                 <li>If audio sounds weird, try refreshing the page</li>
+                 <li>Use the download button to save and test audio files</li>
+                 <li>Speak in a quiet environment for best results</li>
+                 <li>Keep responses short for better audio quality</li>
            </ul>
        </div>
    </div>