Waseem771 committed on
Commit 3c84040 · verified · 1 Parent(s): dccfaa0

Create app.py

Files changed (1)
app.py +144 −0
app.py ADDED
@@ -0,0 +1,144 @@
+ import gradio as gr
+ import whisper
+ from gtts import gTTS
+ from groq import Groq
+ import os
+ import numpy as np
+ import soundfile as sf
+ import logging
+
+ # Read the Groq API key from the environment instead of hardcoding it;
+ # a key committed to the repo is exposed to anyone who can read it.
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
+
+ # Configure logging
+ logging.basicConfig(level=logging.DEBUG)
+
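+ # The imports above imply these PyPI packages (a sketch for requirements.txt;
+ # note that Whisper is published as "openai-whisper", and the key is expected
+ # as a GROQ_API_KEY secret/environment variable):
+ #
+ #   gradio
+ #   openai-whisper
+ #   gTTS
+ #   groq
+ #   numpy
+ #   soundfile
+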
+ # Initialize Whisper model (no API key required)
+ try:
+     whisper_model = whisper.load_model("base")
+     logging.info("Whisper model loaded successfully.")
+ except Exception as e:
+     raise RuntimeError(f"Error loading Whisper model: {e}")
+
+ # Initialize Groq client (API key required for the Groq API)
+ try:
+     client = Groq(api_key=GROQ_API_KEY)
+     logging.info("Groq client initialized successfully.")
+ except Exception as e:
+     raise RuntimeError(f"Error initializing Groq client: {e}")
+
+ # Function to transcribe audio using Whisper
+ def transcribe_audio(audio):
+     try:
+         # Load the audio file with soundfile (Whisper expects float32 samples)
+         logging.debug(f"Loading audio file: {audio}")
+         audio_data, sample_rate = sf.read(audio, dtype='float32')
+         logging.debug(f"Audio loaded with sample rate: {sample_rate}, data shape: {audio_data.shape}")
+
+         # Downmix stereo to mono: sf.read returns shape (n, channels) for
+         # multi-channel input, but Whisper expects a 1-D mono signal
+         if audio_data.ndim > 1:
+             audio_data = audio_data.mean(axis=1)
+
+         # Whisper expects 16 kHz audio; resample by linear interpolation if needed
+         if sample_rate != 16000:
+             logging.debug(f"Resampling audio from {sample_rate} to 16000 Hz")
+             num_samples = int(len(audio_data) * (16000 / sample_rate))
+             audio_data = np.interp(np.linspace(0, len(audio_data) - 1, num_samples),
+                                    np.arange(len(audio_data)),
+                                    audio_data).astype(np.float32)
+             sample_rate = 16000
+
+         # Perform the transcription
+         result = whisper_model.transcribe(audio_data)
+         logging.debug(f"Transcription result: {result['text']}")
+         return result['text']
+     except Exception as e:
+         logging.error(f"Error during transcription: {e}")
+         return f"Error during transcription: {e}"
+
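+ # The np.interp resampling above is plain linear interpolation, which is
+ # serviceable for speech but not high fidelity. If scipy were added as a
+ # dependency (an assumption, not part of this commit), a polyphase resampler
+ # would be a drop-in upgrade inside transcribe_audio:
+ #
+ #     from scipy.signal import resample_poly
+ #     audio_data = resample_poly(audio_data, 16000, sample_rate).astype(np.float32)
+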
+ # Function to get a response from the LLaMA model via the Groq API
+ def get_response(text):
+     try:
+         logging.debug(f"Sending request to Groq API with text: {text}")
+         chat_completion = client.chat.completions.create(
+             messages=[
+                 {
+                     "role": "user",
+                     "content": text,  # Use the transcribed text as input
+                 }
+             ],
+             model="llama3-8b-8192",
+         )
+
+         # Extract and return the model's response
+         response_text = chat_completion.choices[0].message.content
+         logging.debug(f"Received response from Groq API: {response_text}")
+         return response_text
+     except Exception as e:
+         logging.error(f"Error during model response generation: {e}")
+         return f"Error during model response generation: {e}"
+
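+ # Each request above is a single, stateless user turn; the chat completions
+ # API also accepts a system message and prior turns, so a persona could be
+ # added like this (hypothetical wording, not part of the original app):
+ #
+ #     messages=[
+ #         {"role": "system", "content": "Answer in one short spoken paragraph."},
+ #         {"role": "user", "content": text},
+ #     ]
+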
+ # Function to convert text to speech using gTTS (requires network access,
+ # since gTTS calls Google's text-to-speech endpoint)
+ def text_to_speech(text):
+     try:
+         tts = gTTS(text)
+         tts.save("response.mp3")
+         logging.debug("Text-to-speech conversion completed successfully.")
+         return "response.mp3"
+     except Exception as e:
+         logging.error(f"Error during text-to-speech conversion: {e}")
+         return f"Error during text-to-speech conversion: {e}"
+
+ # Combined pipeline for Gradio: speech -> text -> LLM -> speech
+ def chatbot(audio):
+     try:
+         # Step 1: Transcribe the audio input using Whisper
+         user_input = transcribe_audio(audio)
+
+         # Each helper signals failure by returning an "Error ..." string,
+         # so check for that sentinel before continuing
+         if "Error" in user_input:
+             return user_input, None
+
+         logging.debug(f"Transcribed text: {user_input}")
+
+         # Step 2: Get a response from the LLaMA model via the Groq API
+         response_text = get_response(user_input)
+         if "Error" in response_text:
+             return response_text, None
+
+         logging.debug(f"Response text: {response_text}")
+
+         # Step 3: Convert the response text to speech using gTTS
+         response_audio = text_to_speech(response_text)
+         if "Error" in response_audio:
+             return response_audio, None
+
+         # Step 4: Return the response text and the audio file path
+         return response_text, response_audio
+
+     except Exception as e:
+         logging.error(f"Unexpected error occurred: {e}")
+         return f"Unexpected error occurred: {e}", None
+
+ # Gradio interface: record audio, get a text and spoken reply back
+ iface = gr.Interface(
+     fn=chatbot,
+     inputs=gr.Audio(type="filepath"),
+     outputs=[gr.Textbox(label="Response Text"), gr.Audio(label="Response Audio")],
+     live=True,
+     title="Voice-to-Voice Chatbot",
+     description="Speak to the bot, and it will respond with voice.",
+ )
+
+ try:
+     iface.launch()
+ except Exception as e:
+     logging.error(f"Error launching Gradio interface: {e}")
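+ # With live=True the pipeline re-runs whenever the recording changes; dropping
+ # it gives an explicit Submit button instead. Outside a hosted Space,
+ # iface.launch(share=True) would also create a temporary public URL (standard
+ # Gradio behavior, not specific to this app).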