File size: 5,132 Bytes
8afdb95
0bb7775
8afdb95
9e5a4f6
 
 
 
 
0bb7775
9e5a4f6
 
8afdb95
9e5a4f6
 
8afdb95
9e5a4f6
 
586d983
9e5a4f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586d983
9e5a4f6
 
 
 
 
 
 
 
 
 
586d983
9e5a4f6
 
 
 
586d983
9e5a4f6
 
02d76aa
9e5a4f6
 
 
 
 
 
 
 
 
 
 
 
02d76aa
 
9e5a4f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8afdb95
9e5a4f6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import logging
import os
import tempfile

import gradio as gr
import numpy as np
import soundfile as sf
import whisper
from groq import Groq
from gtts import gTTS

# Configure logging.
# DEBUG level so every pipeline stage below (transcription, Groq call, TTS)
# emits its intermediate values.
logging.basicConfig(level=logging.DEBUG)

# Initialize the Groq API key from environment variables.
# Read once at import time; the whole app is useless without it, so we
# fail fast rather than erroring on the first request.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

if not GROQ_API_KEY:
    raise RuntimeError("GROQ_API_KEY environment variable not set.")

# Initialize Whisper model (no API key required).
# "base" is the small multilingual checkpoint; loading happens once at
# startup so each request only pays inference cost.
try:
    whisper_model = whisper.load_model("base")
    logging.info("Whisper model loaded successfully.")
except Exception as e:
    # Re-raise as RuntimeError so a missing/corrupt model aborts startup
    # with a clear message instead of failing mid-request.
    raise RuntimeError(f"Error loading Whisper model: {e}")

# Initialize Groq client (API key required for Groq API).
try:
    client = Groq(
        api_key=GROQ_API_KEY  # Use the API key from the environment variable
    )
    logging.info("Groq client initialized successfully.")
except Exception as e:
    raise RuntimeError(f"Error initializing Groq client: {e}")

# Function to transcribe audio using Whisper
def transcribe_audio(audio):
    """Transcribe an audio file to text with the local Whisper model.

    Args:
        audio: Filesystem path to the recording (as supplied by
            ``gr.Audio(type="filepath")``).

    Returns:
        The transcribed text on success, or a string starting with
        "Error" on failure (callers test for that prefix instead of
        catching exceptions).
    """
    try:
        logging.debug(f"Loading audio file: {audio}")
        # Whisper expects mono float32 PCM sampled at 16 kHz.
        audio_data, sample_rate = sf.read(audio, dtype='float32')
        logging.debug(f"Audio loaded with sample rate: {sample_rate}, data shape: {audio_data.shape}")

        # Down-mix multi-channel audio to mono. sf.read returns a 2-D
        # (frames, channels) array for stereo input, which would break
        # both np.interp below and Whisper itself.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1).astype(np.float32)

        # Whisper expects a specific sample rate.
        if sample_rate != 16000:
            logging.debug(f"Resampling audio from {sample_rate} to 16000 Hz")
            # Linear-interpolation resample to 16000 Hz. Sample positions
            # span [0, len-1] so the final point maps onto the last real
            # input sample (using len() would extrapolate past the end
            # of the signal).
            num_samples = int(len(audio_data) * (16000 / sample_rate))
            audio_data = np.interp(
                np.linspace(0, len(audio_data) - 1, num_samples),
                np.arange(len(audio_data)),
                audio_data,
            ).astype(np.float32)  # Ensure dtype stays float32 for Whisper
            sample_rate = 16000

        # Perform the transcription.
        result = whisper_model.transcribe(audio_data)
        logging.debug(f"Transcription result: {result['text']}")
        return result['text']
    except Exception as e:
        logging.error(f"Error during transcription: {e}")
        return f"Error during transcription: {e}"

# Function to get response from LLaMA model using Groq API
def get_response(text):
    """Ask the LLaMA model (via the Groq API) for a reply to *text*.

    Returns the model's reply, or a string starting with "Error" when
    the API call fails.
    """
    try:
        logging.debug(f"Sending request to Groq API with text: {text}")
        # Single-turn chat: the transcribed text is the only user message.
        completion = client.chat.completions.create(
            messages=[{"role": "user", "content": text}],
            model="llama3-8b-8192",  # Ensure the correct model is used
        )

        # The first (and only) choice carries the model's reply.
        reply = completion.choices[0].message.content
        logging.debug(f"Received response from Groq API: {reply}")
        return reply
    except Exception as e:
        logging.error(f"Error during model response generation: {e}")
        return f"Error during model response generation: {e}"

# Function to convert text to speech using gTTS
def text_to_speech(text):
    """Convert *text* to speech with gTTS and return the MP3 file path.

    Writes to a unique temporary file rather than a fixed "response.mp3"
    in the working directory, so concurrent Gradio requests cannot
    overwrite each other's audio.

    Returns the path to the generated MP3, or a string starting with
    "Error" on failure (callers test for that prefix).
    """
    try:
        tts = gTTS(text)
        # mkstemp guarantees a unique path; close the OS-level descriptor
        # immediately because gTTS writes by filename, not file handle.
        fd, mp3_path = tempfile.mkstemp(suffix=".mp3")
        os.close(fd)
        tts.save(mp3_path)
        logging.debug("Text-to-speech conversion completed successfully.")
        return mp3_path
    except Exception as e:
        logging.error(f"Error during text-to-speech conversion: {e}")
        return f"Error during text-to-speech conversion: {e}"

# Combined function for Gradio
def chatbot(audio):
    """Run the full voice-to-voice pipeline for one recording.

    Pipeline: transcribe the audio, query the LLaMA model, synthesize
    the reply as speech. Each helper signals failure by returning a
    string containing "Error", which short-circuits the pipeline.

    Returns:
        (response_text, response_audio_path); on failure the first
        element is the error message and the second is None.
    """
    try:
        # Step 1: speech -> text via Whisper.
        transcript = transcribe_audio(audio)
        if "Error" in transcript:
            return transcript, None
        logging.debug(f"Transcribed text: {transcript}")

        # Step 2: text -> model reply via the Groq API.
        reply = get_response(transcript)
        if "Error" in reply:
            return reply, None
        logging.debug(f"Response text: {reply}")

        # Step 3: reply text -> speech via gTTS.
        speech_path = text_to_speech(reply)
        if "Error" in speech_path:
            return speech_path, None

        # Step 4: hand both the text and the audio file back to Gradio.
        return reply, speech_path

    except Exception as e:
        logging.error(f"Unexpected error occurred: {e}")
        return f"Unexpected error occurred: {e}", None

# Gradio Interface
# Wires the chatbot pipeline to a simple web UI: one audio input
# (recorded or uploaded, delivered as a file path) and two outputs
# matching chatbot()'s (text, audio-path) return tuple.
iface = gr.Interface(
    fn=chatbot,
    inputs=gr.Audio(type="filepath"),
    outputs=[gr.Textbox(label="Response Text"), gr.Audio(label="Response Audio")],
    live=True,  # re-run automatically when the input changes
    title="Voice-to-Voice Chatbot",
    description="Speak to the bot, and it will respond with voice.",
)

# Launch the app; log (rather than crash on) launch failures such as an
# unavailable port.
try:
    iface.launch()
except Exception as e:
    logging.error(f"Error launching Gradio interface: {e}")