Devakumar868 commited on
Commit
cf427d1
Β·
verified Β·
1 Parent(s): fc9eb64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -53
app.py CHANGED
@@ -1,64 +1,194 @@
1
- import os, torch, numpy as np, soundfile as sf
2
- import gradio as gr
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, BitsAndBytesConfig
4
  import nemo.collections.asr as nemo_asr
5
- from TTS.api import TTS
6
- from sklearn.linear_model import LogisticRegression # for emotion prediction
7
  from datasets import load_dataset
 
8
 
9
  # Configuration
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
- SAMPLE_RATE = 22050
12
- SEED = 42
13
  torch.manual_seed(SEED); np.random.seed(SEED)
14
 
15
- # 1. ASR: Parakeet RNNT
16
- asr = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
17
- model_name="nvidia/parakeet-rnnt-1.1b"
18
- ).to(DEVICE); asr.eval()
19
 
20
- # 2. SER: wav2vec2 emotion classifier
21
- ds = load_dataset("patrickvonplaten/emotion_speech", split="train[:10%]") # sample load
22
- features = ds["audio"]
23
- labels = ds["label"]
24
- # placeholder audio feature extraction
25
- X = np.random.rand(len(features), 128); y = np.array(labels)
26
- clf = LogisticRegression().fit(X, y)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # 3. NLP: LLaMA-3
29
- bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
30
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-7b")
31
- llm = AutoModelForSeq2SeqLM.from_pretrained(
32
- "meta-llama/Llama-3-7b", quantization_config=bnb_config, device_map="auto"
33
- ).to(DEVICE)
34
 
35
- # 4. Emotion Prediction: SER β†’ mapping
36
- def predict_emotion(audio_path):
37
- return clf.predict(np.random.rand(1,128))[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # 5. TTS: Dia 1.6B with emotion conditioning
40
- tts = TTS("nari-labs/Dia-1.6B", progress_bar=False, gpu=torch.cuda.is_available())
41
-
42
- def transcribe(audio):
43
- sf.write("in.wav", audio, SAMPLE_RATE)
44
- return asr.transcribe(["in.wav"])[0].text
45
-
46
- def generate_response(text, emo_tag):
47
- prompt = f"[emotion:{emo_tag}] {text}"
48
- inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
49
- gen = llm.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
50
- return tokenizer.decode(gen[0], skip_special_tokens=True)
51
-
52
- def synthesize(text, emo_tag):
53
- return tts.tts(text=text, speaker_wav=None, style_wav=None)
54
-
55
- def pipeline_fn(audio):
56
- user_text = transcribe(audio); emo = predict_emotion("in.wav")
57
- bot_text = generate_response(user_text, emo); wav = synthesize(bot_text, emo)
58
- return bot_text, (SAMPLE_RATE, wav)
59
-
60
- iface = gr.Interface(
61
- pipeline_fn, gr.Audio(source="microphone", type="numpy"),
62
- [gr.Textbox(), gr.Audio()], title="Emotion-Aware Conversational AI"
63
- )
64
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import os, torch, numpy as np, soundfile as sf, gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
 
3
  import nemo.collections.asr as nemo_asr
4
+ from TTS.api import TTS # Note: using TTS, not coqui_tts
5
+ from sklearn.linear_model import LogisticRegression
6
  from datasets import load_dataset
7
+ import tempfile
8
 
9
  # Configuration
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
+ SEED = 42; SAMPLE_RATE = 22050; TEMPERATURE = 0.7
 
12
  torch.manual_seed(SEED); np.random.seed(SEED)
13
 
14
+ print(f"Using device: {DEVICE}")
15
+ print(f"NumPy version: {np.__version__}")
16
+ print(f"PyTorch version: {torch.__version__}")
 
17
 
18
+ class ConversationalAI:
19
+ def __init__(self):
20
+ print("πŸ”„ Initializing Conversational AI...")
21
+ self.setup_models()
22
+ print("βœ… All models loaded successfully!")
23
+
24
+ def setup_models(self):
25
+ # 1. ASR: Parakeet RNNT
26
+ print("πŸ“’ Loading ASR model...")
27
+ try:
28
+ self.asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
29
+ "nvidia/parakeet-rnnt-1.1b"
30
+ ).to(DEVICE).eval()
31
+ print("βœ… ASR model loaded")
32
+ except Exception as e:
33
+ print(f"⚠️ ASR error: {e}")
34
+ # Fallback to Whisper
35
+ self.asr_pipeline = pipeline(
36
+ "automatic-speech-recognition",
37
+ model="openai/whisper-base.en",
38
+ device=0 if DEVICE == "cuda" else -1
39
+ )
40
+
41
+ # 2. SER: Simple emotion classifier (demo)
42
+ print("🎭 Setting up emotion recognition...")
43
+ # Create dummy SER for demo
44
+ X_demo = np.random.rand(100, 128)
45
+ y_demo = np.random.randint(0, 5, 100) # 5 emotion classes
46
+ self.ser_clf = LogisticRegression().fit(X_demo, y_demo)
47
+
48
+ # 3. LLM: Quantized model for conversation
49
+ print("🧠 Loading LLM model...")
50
+ bnb_cfg = BitsAndBytesConfig(
51
+ load_in_4bit=True,
52
+ bnb_4bit_compute_dtype=torch.float16
53
+ )
54
+
55
+ model_name = "microsoft/DialoGPT-medium"
56
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
57
+ self.tokenizer.pad_token = self.tokenizer.eos_token
58
+
59
+ self.llm_model = AutoModelForCausalLM.from_pretrained(
60
+ model_name,
61
+ quantization_config=bnb_cfg,
62
+ device_map="auto",
63
+ torch_dtype=torch.float16
64
+ )
65
+ print("βœ… LLM model loaded")
66
+
67
+ # 4. TTS: Coqui TTS for speech synthesis
68
+ print("πŸ—£οΈ Loading TTS model...")
69
+ try:
70
+ self.tts = TTS("tts_models/en/ljspeech/tacotron2-DDC").to(DEVICE)
71
+ print("βœ… TTS model loaded")
72
+ except Exception as e:
73
+ print(f"⚠️ TTS error: {e}")
74
+ self.tts = None
75
+
76
+ # Memory cleanup
77
+ if DEVICE == "cuda":
78
+ torch.cuda.empty_cache()
79
+
80
+ def transcribe(self, audio):
81
+ try:
82
+ if hasattr(self, 'asr_model'):
83
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
84
+ sf.write(temp_file.name, audio[1], audio[0])
85
+ transcription = self.asr_model.transcribe([temp_file.name])[0]
86
+ os.unlink(temp_file.name)
87
+ return transcription.text if hasattr(transcription, 'text') else str(transcription)
88
+ else:
89
+ return self.asr_pipeline({"sampling_rate": audio[0], "raw": audio[1]})["text"]
90
+ except Exception as e:
91
+ print(f"ASR Error: {e}")
92
+ return "Sorry, I couldn't understand the audio."
93
+
94
+ def predict_emotion(self):
95
+ # Simple emotion prediction (demo)
96
+ return self.ser_clf.predict(np.random.rand(1, 128))[0]
97
+
98
+ def generate_response(self, text, emo):
99
+ try:
100
+ prompt = f"Human: {text}\nAssistant:"
101
+ inputs = self.tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True).to(DEVICE)
102
+
103
+ with torch.no_grad():
104
+ outputs = self.llm_model.generate(
105
+ inputs,
106
+ max_length=inputs.shape[1] + 100,
107
+ temperature=TEMPERATURE,
108
+ do_sample=True,
109
+ pad_token_id=self.tokenizer.eos_token_id
110
+ )
111
+
112
+ response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
113
+ return response.split("Human:")[0].strip() or "I understand. Please tell me more."
114
+ except Exception as e:
115
+ print(f"LLM Error: {e}")
116
+ return "I'm having trouble processing that. Could you please rephrase?"
117
+
118
+ def synthesize(self, text):
119
+ try:
120
+ if self.tts:
121
+ wav = self.tts.tts(text=text)
122
+ if isinstance(wav, list):
123
+ wav = np.array(wav, dtype=np.float32)
124
+ wav = wav / np.max(np.abs(wav)) if np.max(np.abs(wav)) > 0 else wav
125
+ return (SAMPLE_RATE, (wav * 32767).astype(np.int16))
126
+ else:
127
+ return (SAMPLE_RATE, np.zeros(SAMPLE_RATE, dtype=np.int16))
128
+ except Exception as e:
129
+ print(f"TTS Error: {e}")
130
+ return (SAMPLE_RATE, np.zeros(SAMPLE_RATE, dtype=np.int16))
131
+
132
+ def process_conversation(self, audio_input, chat_history):
133
+ if audio_input is None:
134
+ return chat_history, None, ""
135
+
136
+ try:
137
+ # Pipeline: ASR -> SER -> LLM -> TTS
138
+ user_text = self.transcribe(audio_input)
139
+ if not user_text.strip():
140
+ return chat_history, None, "No speech detected."
141
+
142
+ emo = self.predict_emotion()
143
+ ai_response = self.generate_response(user_text, emo)
144
+ audio_response = self.synthesize(ai_response)
145
+
146
+ chat_history.append([user_text, ai_response])
147
+
148
+ if DEVICE == "cuda":
149
+ torch.cuda.empty_cache()
150
+
151
+ return chat_history, audio_response, f"You said: {user_text}"
152
+ except Exception as e:
153
+ error_msg = f"Error: {e}"
154
+ print(error_msg)
155
+ return chat_history, None, error_msg
156
 
157
+ # Initialize AI system
158
+ print("πŸš€ Starting initialization...")
159
+ ai_system = ConversationalAI()
 
 
 
160
 
161
+ # Gradio interface
162
+ def create_interface():
163
+ with gr.Blocks(title="Emotion-Aware Conversational AI") as demo:
164
+ gr.HTML("<h1>πŸ€– Emotion-Aware Conversational AI</h1>")
165
+
166
+ with gr.Row():
167
+ with gr.Column():
168
+ chatbot = gr.Chatbot(label="Conversation", height=400)
169
+ audio_input = gr.Audio(label="🎀 Speak", sources=["microphone"], type="numpy")
170
+
171
+ with gr.Row():
172
+ submit_btn = gr.Button("πŸ’¬ Process", variant="primary")
173
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
174
+
175
+ with gr.Column():
176
+ audio_output = gr.Audio(label="πŸ”Š AI Response", autoplay=True)
177
+ status = gr.Textbox(label="πŸ“Š Status", lines=3, interactive=False)
178
+
179
+ def process_audio(audio, history):
180
+ return ai_system.process_conversation(audio, history)
181
+
182
+ def clear_chat():
183
+ return [], None, "Conversation cleared."
184
+
185
+ submit_btn.click(process_audio, [audio_input, chatbot], [chatbot, audio_output, status])
186
+ clear_btn.click(clear_chat, outputs=[chatbot, audio_output, status])
187
+ audio_input.change(process_audio, [audio_input, chatbot], [chatbot, audio_output, status])
188
+
189
+ return demo
190
 
191
+ # Launch
192
+ if __name__ == "__main__":
193
+ demo = create_interface()
194
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)