Devakumar868 committed (verified)
Commit: 322ba51 · Parent(s): cf427d1

Update app.py

Files changed (1): app.py (+144, -47)
app.py CHANGED
@@ -1,19 +1,23 @@
 import os, torch, numpy as np, soundfile as sf, gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
 import nemo.collections.asr as nemo_asr
-from TTS.api import TTS  # Note: using TTS, not coqui_tts
+from TTS.api import TTS
 from sklearn.linear_model import LogisticRegression
 from datasets import load_dataset
 import tempfile
+import gc

 # Configuration
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 SEED = 42; SAMPLE_RATE = 22050; TEMPERATURE = 0.7
 torch.manual_seed(SEED); np.random.seed(SEED)

-print(f"Using device: {DEVICE}")
-print(f"NumPy version: {np.__version__}")
-print(f"PyTorch version: {torch.__version__}")
+print(f"🚀 System Info:")
+print(f"Device: {DEVICE}")
+print(f"NumPy: {np.__version__}")
+print(f"PyTorch: {torch.__version__}")
+if torch.cuda.is_available():
+    print(f"CUDA: {torch.version.cuda}")

 class ConversationalAI:
     def __init__(self):
@@ -28,28 +32,32 @@ class ConversationalAI:
             self.asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
                 "nvidia/parakeet-rnnt-1.1b"
             ).to(DEVICE).eval()
-            print("✅ ASR model loaded")
+            print("✅ Parakeet ASR loaded")
         except Exception as e:
-            print(f"⚠️ ASR error: {e}")
-            # Fallback to Whisper
+            print(f"⚠️ Parakeet failed: {e}")
+            print("🔄 Loading Whisper fallback...")
             self.asr_pipeline = pipeline(
                 "automatic-speech-recognition",
                 model="openai/whisper-base.en",
                 device=0 if DEVICE == "cuda" else -1
             )
+            print("✅ Whisper ASR loaded")

-        # 2. SER: Simple emotion classifier (demo)
+        # 2. SER: Emotion classifier (simplified for demo)
         print("🎭 Setting up emotion recognition...")
-        # Create dummy SER for demo
         X_demo = np.random.rand(100, 128)
-        y_demo = np.random.randint(0, 5, 100)  # 5 emotion classes
+        y_demo = np.random.randint(0, 5, 100)  # 5 emotions: neutral, happy, sad, angry, surprised
         self.ser_clf = LogisticRegression().fit(X_demo, y_demo)
+        self.emotion_labels = ["neutral", "happy", "sad", "angry", "surprised"]
+        print("✅ SER model ready")

-        # 3. LLM: Quantized model for conversation
-        print("🧠 Loading LLM model...")
+        # 3. LLM: Conversational model
+        print("🧠 Loading LLM...")
         bnb_cfg = BitsAndBytesConfig(
             load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4"
         )

         model_name = "microsoft/DialoGPT-medium"
@@ -60,15 +68,16 @@ class ConversationalAI:
             model_name,
             quantization_config=bnb_cfg,
             device_map="auto",
-            torch_dtype=torch.float16
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True
         )
-        print("✅ LLM model loaded")
+        print("✅ LLM loaded")

-        # 4. TTS: Coqui TTS for speech synthesis
-        print("🗣️ Loading TTS model...")
+        # 4. TTS: Text-to-Speech
+        print("🗣️ Loading TTS...")
         try:
             self.tts = TTS("tts_models/en/ljspeech/tacotron2-DDC").to(DEVICE)
-            print("✅ TTS model loaded")
+            print("✅ TTS loaded")
         except Exception as e:
             print(f"⚠️ TTS error: {e}")
             self.tts = None
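
The Coqui model id retained above can also be exercised outside the app; a minimal sketch using the same TTS.api import (illustrative, not part of this commit; tts_to_file writes a WAV to disk instead of the in-memory array that synthesize() builds):

from TTS.api import TTS

tts = TTS("tts_models/en/ljspeech/tacotron2-DDC")
tts.tts_to_file(text="Testing the LJSpeech Tacotron2 voice.", file_path="out.wav")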
@@ -76,28 +85,36 @@ class ConversationalAI:
         # Memory cleanup
         if DEVICE == "cuda":
             torch.cuda.empty_cache()
+            gc.collect()

     def transcribe(self, audio):
+        """Convert speech to text"""
         try:
             if hasattr(self, 'asr_model'):
+                # Use Parakeet
                 temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
                 sf.write(temp_file.name, audio[1], audio[0])
                 transcription = self.asr_model.transcribe([temp_file.name])[0]
                 os.unlink(temp_file.name)
                 return transcription.text if hasattr(transcription, 'text') else str(transcription)
             else:
+                # Use Whisper
                 return self.asr_pipeline({"sampling_rate": audio[0], "raw": audio[1]})["text"]
         except Exception as e:
             print(f"ASR Error: {e}")
             return "Sorry, I couldn't understand the audio."

     def predict_emotion(self):
-        # Simple emotion prediction (demo)
-        return self.ser_clf.predict(np.random.rand(1, 128))[0]
+        """Predict emotion from audio (simplified demo)"""
+        emotion_idx = self.ser_clf.predict(np.random.rand(1, 128))[0]
+        return self.emotion_labels[emotion_idx]

-    def generate_response(self, text, emo):
+    def generate_response(self, text, emotion):
+        """Generate conversational response"""
         try:
-            prompt = f"Human: {text}\nAssistant:"
+            # Create emotion-aware prompt
+            prompt = f"Human: {text}\nAssistant (feeling {emotion}):"
+
             inputs = self.tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True).to(DEVICE)

             with torch.no_grad():
@@ -106,89 +123,169 @@ class ConversationalAI:
                     max_length=inputs.shape[1] + 100,
                     temperature=TEMPERATURE,
                     do_sample=True,
-                    pad_token_id=self.tokenizer.eos_token_id
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    no_repeat_ngram_size=2,
+                    top_p=0.9
                 )

             response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
-            return response.split("Human:")[0].strip() or "I understand. Please tell me more."
+            response = response.split("Human:")[0].strip()
+
+            return response if response else "I understand. Please tell me more."
         except Exception as e:
             print(f"LLM Error: {e}")
             return "I'm having trouble processing that. Could you please rephrase?"

     def synthesize(self, text):
+        """Convert text to speech"""
         try:
             if self.tts:
                 wav = self.tts.tts(text=text)
                 if isinstance(wav, list):
                     wav = np.array(wav, dtype=np.float32)
+                # Normalize audio
                 wav = wav / np.max(np.abs(wav)) if np.max(np.abs(wav)) > 0 else wav
                 return (SAMPLE_RATE, (wav * 32767).astype(np.int16))
             else:
+                # Return silence if TTS fails
                 return (SAMPLE_RATE, np.zeros(SAMPLE_RATE, dtype=np.int16))
         except Exception as e:
             print(f"TTS Error: {e}")
             return (SAMPLE_RATE, np.zeros(SAMPLE_RATE, dtype=np.int16))

     def process_conversation(self, audio_input, chat_history):
+        """Main pipeline: Speech -> Emotion -> LLM -> Speech"""
         if audio_input is None:
             return chat_history, None, ""

         try:
-            # Pipeline: ASR -> SER -> LLM -> TTS
+            # Step 1: Speech to Text
             user_text = self.transcribe(audio_input)
             if not user_text.strip():
                 return chat_history, None, "No speech detected."

-            emo = self.predict_emotion()
-            ai_response = self.generate_response(user_text, emo)
+            # Step 2: Emotion Recognition
+            emotion = self.predict_emotion()
+
+            # Step 3: Generate Response
+            ai_response = self.generate_response(user_text, emotion)
+
+            # Step 4: Text to Speech
             audio_response = self.synthesize(ai_response)

+            # Update chat history
             chat_history.append([user_text, ai_response])

+            # Memory cleanup
             if DEVICE == "cuda":
                 torch.cuda.empty_cache()
+                gc.collect()

-            return chat_history, audio_response, f"You said: {user_text}"
+            return chat_history, audio_response, f"You said: {user_text} (detected emotion: {emotion})"
+
         except Exception as e:
-            error_msg = f"Error: {e}"
+            error_msg = f"Error processing conversation: {e}"
             print(error_msg)
             return chat_history, None, error_msg

 # Initialize AI system
-print("🚀 Starting initialization...")
+print("🚀 Starting Conversational AI...")
 ai_system = ConversationalAI()

-# Gradio interface
+# Gradio Interface
 def create_interface():
-    with gr.Blocks(title="Emotion-Aware Conversational AI") as demo:
-        gr.HTML("<h1>🤖 Emotion-Aware Conversational AI</h1>")
+    with gr.Blocks(
+        title="Emotion-Aware Conversational AI",
+        theme=gr.themes.Soft()
+    ) as demo:
+
+        gr.HTML("""
+        <div style="text-align: center; margin-bottom: 2rem;">
+            <h1>🤖 Emotion-Aware Conversational AI</h1>
+            <p>Speak naturally and get intelligent responses with emotion recognition</p>
+        </div>
+        """)

         with gr.Row():
-            with gr.Column():
-                chatbot = gr.Chatbot(label="Conversation", height=400)
-                audio_input = gr.Audio(label="🎤 Speak", sources=["microphone"], type="numpy")
+            with gr.Column(scale=2):
+                chatbot = gr.Chatbot(
+                    label="Conversation History",
+                    height=400,
+                    show_copy_button=True
+                )
+
+                audio_input = gr.Audio(
+                    label="🎤 Speak to AI",
+                    sources=["microphone"],
+                    type="numpy",
+                    format="wav"
+                )

                 with gr.Row():
-                    submit_btn = gr.Button("💬 Process", variant="primary")
-                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")
+                    submit_btn = gr.Button("💬 Process Speech", variant="primary", scale=2)
+                    clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", scale=1)

-            with gr.Column():
-                audio_output = gr.Audio(label="🔊 AI Response", autoplay=True)
-                status = gr.Textbox(label="📊 Status", lines=3, interactive=False)
+            with gr.Column(scale=1):
+                audio_output = gr.Audio(
+                    label="🔊 AI Response",
+                    type="numpy",
+                    autoplay=True
+                )
+
+                status_display = gr.Textbox(
+                    label="📊 Status",
+                    lines=3,
+                    interactive=False
+                )
+
+                gr.HTML(f"""
+                <div style="padding: 1rem; background: #f0f9ff; border-radius: 0.5rem;">
+                    <h3>🔧 System Info</h3>
+                    <p><strong>Device:</strong> {DEVICE.upper()}</p>
+                    <p><strong>PyTorch:</strong> {torch.__version__}</p>
+                    <p><strong>Models:</strong> Parakeet + DialoGPT + TTS</p>
+                    <p><strong>Features:</strong> Emotion Recognition</p>
+                </div>
+                """)

         def process_audio(audio, history):
             return ai_system.process_conversation(audio, history)

-        def clear_chat():
+        def clear_conversation():
+            if DEVICE == "cuda":
+                torch.cuda.empty_cache()
+                gc.collect()
             return [], None, "Conversation cleared."

-        submit_btn.click(process_audio, [audio_input, chatbot], [chatbot, audio_output, status])
-        clear_btn.click(clear_chat, outputs=[chatbot, audio_output, status])
-        audio_input.change(process_audio, [audio_input, chatbot], [chatbot, audio_output, status])
+        # Event handlers
+        submit_btn.click(
+            fn=process_audio,
+            inputs=[audio_input, chatbot],
+            outputs=[chatbot, audio_output, status_display]
+        )
+
+        clear_btn.click(
+            fn=clear_conversation,
+            outputs=[chatbot, audio_output, status_display]
+        )
+
+        audio_input.change(
+            fn=process_audio,
+            inputs=[audio_input, chatbot],
+            outputs=[chatbot, audio_output, status_display]
+        )

     return demo

-# Launch
+# Launch application
 if __name__ == "__main__":
+    print("🌟 Creating interface...")
     demo = create_interface()
-    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
+
+    print("🚀 Launching application...")
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=True,
+        show_error=True
+    )
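
One detail worth flagging about the generation kwargs in the hunk above: DialoGPT reuses the GPT-2 vocabulary, which ships without a padding token, so transformers warns on every generate() call unless pad_token_id is supplied explicitly. That is why pad_token_id=self.tokenizer.eos_token_id is retained alongside the new no_repeat_ngram_size and top_p sampling settings. A quick check (illustrative only):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
print(tok.pad_token)     # None: the GPT-2 vocab defines no pad token
print(tok.eos_token_id)  # 50256, the value reused as pad_token_id above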
 
 
 
 
 
 
 
 
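
To smoke-test the updated pipeline without launching the Gradio UI, something along these lines should work once the models have loaded (a sketch, not part of the commit; a synthetic tone will normally take the "No speech detected." path, so substitute a real recording to exercise the full ASR -> SER -> LLM -> TTS chain):

import numpy as np

# Gradio-style (sample_rate, samples) input: 1 second of a 440 Hz tone
sr = 16000
t = np.linspace(0, 1, sr, endpoint=False)
tone = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

history, audio_out, status = ai_system.process_conversation((sr, tone), [])
print(status)  # "No speech detected." for the tone; real speech yields
               # "You said: ... (detected emotion: ...)"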