import asyncio
import os
import tempfile

import edge_tts
import gradio as gr
import torch
import whisper
from huggingface_hub import InferenceClient
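
# Likely pip dependencies for the imports above (an assumption, not taken from
# the original Space): gradio, edge-tts, openai-whisper, torch, huggingface_hub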
# Get the Hugging Face token from the environment variable
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable is not set")

# Initialize the Hugging Face Inference Client
client = InferenceClient(
    "mistralai/Mistral-Nemo-Instruct-2407",
    token=hf_token,
)

# Load the Whisper model
whisper_model = whisper.load_model("tiny.en", device="cpu")

# Initialize an empty chat history
chat_history = []

async def text_to_speech_stream(text):
    """Convert text to speech using edge_tts and return the audio file path."""
    communicate = edge_tts.Communicate(text, "en-US-AvaMultilingualNeural")

    audio_data = b""
    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            audio_data += chunk["data"]

    # Save the audio data to a temporary file and return its path
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
        temp_file.write(audio_data)
    return temp_file.name

def whisper_speech_to_text(audio):
    """Convert speech to text using the Whisper model."""
    try:
        result = whisper_model.transcribe(audio)
        return result["text"]
    except Exception as e:
        print(f"Whisper Error: {e}")
        return None
    finally:
        # Clear the CUDA cache (only relevant when a GPU is actually in use)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

async def chat_with_ai(message, history):
    global chat_history

    # Add the user message to the chat history
    chat_history.append({"role": "user", "content": message})

    try:
        # Send the chat completion request
        response = client.chat_completion(
            messages=[{"role": "system", "content": "You are a helpful voice assistant. Provide concise and clear responses to user queries."}] + chat_history,
            max_tokens=800,
            temperature=0.7,
        )
        response_text = response.choices[0].message.content

        # Add the assistant's response to the chat history
        chat_history.append({"role": "assistant", "content": response_text})

        # Generate speech for the response
        audio_path = await text_to_speech_stream(response_text)
        return response_text, audio_path
    except Exception as e:
        print(f"Error: {e}")
        return str(e), None

def transcribe_and_chat(audio):
    if audio is None:
        return "Sorry, no audio was received.", None

    # Transcribe the audio to text
    text = whisper_speech_to_text(audio)
    if text is None:
        return "Sorry, I couldn't understand the audio.", None

    # Chat with the AI using the transcribed text
    response, audio_path = asyncio.run(chat_with_ai(text, []))
    return response, audio_path

# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# AI Voice Assistant")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Speak here")
            text_input = gr.Textbox(label="Or type your message here")
        with gr.Column():
            chat_output = gr.Textbox(label="AI Response")
            audio_output = gr.Audio(label="AI Voice Response", interactive=False)

    audio_button = gr.Button("Send Audio")
    text_button = gr.Button("Send Text")
    # Custom JavaScript to handle spacebar presses and auto-play the audio response
    gr.HTML("""
    <script>
    document.addEventListener('keydown', function(event) {
        if (event.code === 'Space') {
            document.querySelector('input[type="file"]').click();
        }
    });

    document.addEventListener('gradioAudioLoaded', function(event) {
        var audioElement = document.querySelector('audio');
        if (audioElement) {
            audioElement.play();
        }
    });
    </script>
    """)
    audio_button.click(transcribe_and_chat, inputs=audio_input, outputs=[chat_output, audio_output])
    text_button.click(lambda x: asyncio.run(chat_with_ai(x, [])), inputs=text_input, outputs=[chat_output, audio_output])

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
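
# Example local run (assumed workflow, not part of the original Space):
#   export HF_TOKEN=<your Hugging Face access token>
#   python app.py
# then open http://localhost:7860 in a browser.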