# Voice-to-Voice Chatbot (Hugging Face Spaces app)
# Pipeline: Whisper (speech-to-text) -> Groq LLaMA (reply) -> gTTS (text-to-speech).
import gradio as gr | |
import whisper | |
from gtts import gTTS | |
from groq import Groq | |
import os | |
import numpy as np | |
import soundfile as sf | |
import logging | |
# Configure logging; the pipeline functions below emit DEBUG traces.
logging.basicConfig(level=logging.DEBUG)

# The Groq API key must come from the environment; fail fast if missing.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError("GROQ_API_KEY environment variable not set.")

# Load the Whisper speech-to-text model (local inference, no API key required).
try:
    whisper_model = whisper.load_model("base")
    logging.info("Whisper model loaded successfully.")
except Exception as e:
    # Chain the original exception so the root cause survives in the traceback.
    raise RuntimeError(f"Error loading Whisper model: {e}") from e

# Create the Groq client used for chat completions (API key required).
try:
    client = Groq(
        api_key=GROQ_API_KEY  # Use the API key from the environment variable
    )
    logging.info("Groq client initialized successfully.")
except Exception as e:
    raise RuntimeError(f"Error initializing Groq client: {e}") from e
# Function to transcribe audio using Whisper | |
def transcribe_audio(audio):
    """Transcribe an audio file to text with the local Whisper model.

    Args:
        audio: Path to the audio file recorded by the Gradio component.

    Returns:
        The transcribed text, or an "Error during transcription: ..." string
        on failure (callers check for this prefix instead of catching).
    """
    try:
        logging.debug(f"Loading audio file: {audio}")
        # Whisper expects a float32 waveform.
        audio_data, sample_rate = sf.read(audio, dtype='float32')
        logging.debug(f"Audio loaded with sample rate: {sample_rate}, data shape: {audio_data.shape}")

        # Downmix multi-channel (e.g. stereo) recordings to mono; the 1-D
        # interpolation below would otherwise fail on a 2-D array.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)

        # Whisper expects 16 kHz input; resample by linear interpolation.
        if sample_rate != 16000:
            logging.debug(f"Resampling audio from {sample_rate} to 16000 Hz")
            num_samples = int(len(audio_data) * (16000 / sample_rate))
            # Sample positions must stay within [0, len-1]; an upper bound of
            # len(audio_data) would extrapolate past the last sample.
            positions = np.linspace(0, len(audio_data) - 1, num_samples)
            audio_data = np.interp(
                positions,
                np.arange(len(audio_data)),
                audio_data,
            ).astype(np.float32)  # keep dtype float32 after interpolation
            sample_rate = 16000

        # Perform the transcription on the in-memory waveform.
        result = whisper_model.transcribe(audio_data)
        logging.debug(f"Transcription result: {result['text']}")
        return result['text']
    except Exception as e:
        logging.error(f"Error during transcription: {e}")
        return f"Error during transcription: {e}"
# Function to get response from LLaMA model using Groq API | |
def get_response(text):
    """Send *text* to the LLaMA model via the Groq API and return the reply.

    On failure, returns an "Error during model response generation: ..."
    string instead of raising, matching the pipeline's error convention.
    """
    try:
        logging.debug(f"Sending request to Groq API with text: {text}")
        user_message = {
            "role": "user",
            "content": text,  # the transcribed speech becomes the prompt
        }
        chat_completion = client.chat.completions.create(
            messages=[user_message],
            model="llama3-8b-8192",  # Ensure the correct model is used
        )
        # The reply text lives on the first choice of the completion.
        response_text = chat_completion.choices[0].message.content
        logging.debug(f"Received response from Groq API: {response_text}")
        return response_text
    except Exception as e:
        logging.error(f"Error during model response generation: {e}")
        return f"Error during model response generation: {e}"
# Function to convert text to speech using gTTS | |
def text_to_speech(text):
    """Synthesize *text* to speech with gTTS and return the MP3 path.

    The audio is saved as "response.mp3" in the working directory and is
    overwritten on every call. On failure, returns an
    "Error during text-to-speech conversion: ..." string.
    """
    output_path = "response.mp3"
    try:
        gTTS(text).save(output_path)
        logging.debug("Text-to-speech conversion completed successfully.")
        return output_path
    except Exception as e:
        logging.error(f"Error during text-to-speech conversion: {e}")
        return f"Error during text-to-speech conversion: {e}"
# Combined function for Gradio | |
def chatbot(audio):
    """Run the full voice pipeline: speech -> text -> LLM reply -> speech.

    Args:
        audio: Path to the recorded input audio from the Gradio component.

    Returns:
        A (text, audio_path) tuple: the reply text and the path to the reply
        MP3, or (error message, None) when any stage fails.
    """
    try:
        # Step 1: transcribe the audio input using Whisper.
        user_input = transcribe_audio(audio)
        # The helpers signal failure by returning a string that begins with
        # "Error during ..."; the previous bare substring test
        # (`"Error" in user_input`) also tripped on legitimate transcripts
        # that merely contained the word "Error".
        if user_input.startswith("Error during"):
            return user_input, None
        logging.debug(f"Transcribed text: {user_input}")

        # Step 2: get a reply from the LLaMA model via the Groq API.
        response_text = get_response(user_input)
        if response_text.startswith("Error during"):
            return response_text, None
        logging.debug(f"Response text: {response_text}")

        # Step 3: convert the reply text to speech with gTTS.
        response_audio = text_to_speech(response_text)
        if response_audio.startswith("Error during"):
            return response_audio, None

        # Step 4: return the reply text and the reply audio file path.
        return response_text, response_audio
    except Exception as e:
        logging.error(f"Unexpected error occurred: {e}")
        return f"Unexpected error occurred: {e}", None
# Gradio UI: microphone in, (reply text, reply audio) out.
response_outputs = [
    gr.Textbox(label="Response Text"),
    gr.Audio(label="Response Audio"),
]
iface = gr.Interface(
    fn=chatbot,
    inputs=gr.Audio(type="filepath"),
    outputs=response_outputs,
    live=True,
    title="Voice-to-Voice Chatbot",
    description="Speak to the bot, and it will respond with voice.",
)

# Launch the app; log (rather than crash on) any startup failure.
try:
    iface.launch()
except Exception as e:
    logging.error(f"Error launching Gradio interface: {e}")