Spaces:
Sleeping
Sleeping
File size: 5,132 Bytes
8afdb95 0bb7775 8afdb95 9e5a4f6 0bb7775 9e5a4f6 8afdb95 9e5a4f6 8afdb95 9e5a4f6 586d983 9e5a4f6 586d983 9e5a4f6 586d983 9e5a4f6 586d983 9e5a4f6 02d76aa 9e5a4f6 02d76aa 9e5a4f6 8afdb95 9e5a4f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import gradio as gr
import whisper
from gtts import gTTS
from groq import Groq
import os
import numpy as np
import soundfile as sf
import logging
# Configure logging for the whole app (DEBUG so every pipeline stage is traced).
logging.basicConfig(level=logging.DEBUG)

# Groq API key must come from the environment; fail fast at startup if missing.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError("GROQ_API_KEY environment variable not set.")

# Load the local Whisper speech-to-text model (no API key required).
# Chained `from e` preserves the original traceback for debugging.
try:
    whisper_model = whisper.load_model("base")
    logging.info("Whisper model loaded successfully.")
except Exception as e:
    raise RuntimeError(f"Error loading Whisper model: {e}") from e

# Initialize the Groq client used for LLaMA chat completions.
try:
    client = Groq(
        api_key=GROQ_API_KEY  # Use the API key from the environment variable
    )
    logging.info("Groq client initialized successfully.")
except Exception as e:
    raise RuntimeError(f"Error initializing Groq client: {e}") from e
# Function to transcribe audio using Whisper
def transcribe_audio(audio):
    """Transcribe an audio file to text with the local Whisper model.

    Args:
        audio: Filesystem path to the recorded audio (Gradio passes a
            filepath because the input uses ``gr.Audio(type="filepath")``).

    Returns:
        The transcribed text on success, or a string starting with
        "Error during transcription:" on failure — callers detect errors
        by inspecting the returned string rather than catching exceptions.
    """
    try:
        logging.debug("Loading audio file: %s", audio)
        # float32 is what Whisper expects for raw ndarray input.
        audio_data, sample_rate = sf.read(audio, dtype='float32')
        logging.debug("Audio loaded with sample rate: %s, data shape: %s",
                      sample_rate, audio_data.shape)

        # Downmix multi-channel audio to mono: soundfile returns (frames,
        # channels) for stereo, and both np.interp and Whisper need a 1-D signal.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1).astype(np.float32)

        # Whisper expects 16 kHz input; resample by linear interpolation.
        if sample_rate != 16000:
            logging.debug("Resampling audio from %s to 16000 Hz", sample_rate)
            num_samples = int(len(audio_data) * (16000 / sample_rate))
            # Endpoint is len-1 so the sample positions stay inside the
            # range of np.arange(len(audio_data)) (last valid index).
            audio_data = np.interp(
                np.linspace(0, len(audio_data) - 1, num_samples),
                np.arange(len(audio_data)),
                audio_data,
            ).astype(np.float32)  # Ensure dtype stays float32 for Whisper
            sample_rate = 16000

        # Perform the transcription on the in-memory waveform.
        result = whisper_model.transcribe(audio_data)
        logging.debug("Transcription result: %s", result['text'])
        return result['text']
    except Exception as e:
        logging.error("Error during transcription: %s", e)
        return f"Error during transcription: {e}"
# Function to get response from LLaMA model using Groq API
def get_response(text):
    """Send *text* to the Groq-hosted LLaMA model and return its reply.

    Args:
        text: The user's message (typically the Whisper transcription).

    Returns:
        The model's reply on success, or a string starting with
        "Error during model response generation:" on failure — callers
        detect errors by inspecting the returned string.
    """
    try:
        logging.debug("Sending request to Groq API with text: %s", text)
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": text,  # Using the transcribed text as input
                }
            ],
            model="llama3-8b-8192",  # Ensure the correct model is used
        )
        # Extract and return the model's response (first choice).
        response_text = chat_completion.choices[0].message.content
        logging.debug("Received response from Groq API: %s", response_text)
        return response_text
    except Exception as e:
        logging.error("Error during model response generation: %s", e)
        return f"Error during model response generation: {e}"
# Function to convert text to speech using gTTS
def text_to_speech(text):
    """Convert *text* to spoken audio with gTTS and save it to disk.

    Args:
        text: The text to synthesize.

    Returns:
        The path "response.mp3" on success (NOTE: a fixed filename in the
        working directory, overwritten on every call), or a string starting
        with "Error during text-to-speech conversion:" on failure.
    """
    try:
        tts = gTTS(text)
        tts.save("response.mp3")
        logging.debug("Text-to-speech conversion completed successfully.")
        return "response.mp3"
    except Exception as e:
        logging.error("Error during text-to-speech conversion: %s", e)
        return f"Error during text-to-speech conversion: {e}"
# Combined function for Gradio: speech -> text -> LLM -> speech.
def chatbot(audio):
    """Run the full voice pipeline for one user utterance.

    Args:
        audio: Filesystem path to the user's recording (from gr.Audio).

    Returns:
        A ``(response_text, response_audio)`` tuple. ``response_audio`` is
        the path of the synthesized reply, or ``None`` when any stage
        failed (in which case the text element carries the error message).
    """
    try:
        # Step 1: Transcribe the audio input using Whisper.
        user_input = transcribe_audio(audio)
        # The pipeline helpers signal failure via strings that start with
        # "Error during ..."; match on that prefix instead of a bare
        # `"Error" in ...` so a transcription merely containing the word
        # "Error" is not misclassified as a failure.
        if user_input.startswith("Error during"):
            return user_input, None
        logging.debug("Transcribed text: %s", user_input)

        # Step 2: Get a response from the LLaMA model via the Groq API.
        response_text = get_response(user_input)
        if response_text.startswith("Error during"):
            return response_text, None
        logging.debug("Response text: %s", response_text)

        # Step 3: Convert the response text to speech using gTTS.
        response_audio = text_to_speech(response_text)
        if response_audio.startswith("Error during"):
            return response_audio, None

        # Step 4: Return the response text and response audio file path.
        return response_text, response_audio
    except Exception as e:
        logging.error("Unexpected error occurred: %s", e)
        return f"Unexpected error occurred: {e}", None
# Gradio Interface: microphone/file input in, reply text + synthesized audio out.
iface = gr.Interface(
    fn=chatbot,
    inputs=gr.Audio(type="filepath"),  # chatbot() expects a file path
    outputs=[gr.Textbox(label="Response Text"), gr.Audio(label="Response Audio")],
    live=True,
    title="Voice-to-Voice Chatbot",
    description="Speak to the bot, and it will respond with voice.",
)

# Launch the app; log (rather than crash on) launch failures.
try:
    iface.launch()
except Exception as e:
    logging.error("Error launching Gradio interface: %s", e)
|