import gradio as gr
import asyncio
import edge_tts
import os
from huggingface_hub import InferenceClient
import whisper
import torch
import tempfile


# Get the Hugging Face token from environment variable
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable is not set")

# Initialize the Hugging Face Inference Client
client = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407", token=hf_token)

# Load the Whisper model
whisper_model = whisper.load_model("tiny.en", device='cuda' if torch.cuda.is_available() else 'cpu')

# Initialize an empty chat history
chat_history = []

async def text_to_speech_stream(text):
    """Convert text to speech using edge_tts and return the audio file path."""
    communicate = edge_tts.Communicate(text, "en-US-AvaMultilingualNeural")
    audio_data = b""

    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            audio_data += chunk["data"]

    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
        temp_file.write(audio_data)
        return temp_file.name
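
# Standalone usage sketch for text_to_speech_stream (assumes it is called from a plain
# script where no event loop is running yet; inside this app it is awaited instead):
#
#   mp3_path = asyncio.run(text_to_speech_stream("Hello from edge-tts"))
#   # mp3_path points at a temporary .mp3 file; the caller is responsible for removing it.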

def whisper_speech_to_text(audio):
    """Convert speech to text using Whisper model."""
    try:
        result = whisper_model.transcribe(audio)
        return result['text']
    except Exception as e:
        print(f"Whisper Error: {e}")
        return None
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
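
# Usage sketch for whisper_speech_to_text ("recording.wav" is a placeholder path, not a
# file shipped with this app):
#
#   transcript = whisper_speech_to_text("recording.wav")  # transcript string, or None on failure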

async def chat_with_ai(message):
    """Send the chat history plus the new message to the LLM and return (text, audio_path)."""
    global chat_history
    
    chat_history.append({"role": "user", "content": message})
    
    try:
        response = client.chat_completion(
            messages=[{"role": "system", "content": "You are a helpful voice assistant. Provide concise and clear responses to user queries."}] + chat_history,
            max_tokens=800,
            temperature=0.7
        )
        
        response_text = response.choices[0].message.content
        chat_history.append({"role": "assistant", "content": response_text})
        
        audio_path = await text_to_speech_stream(response_text)
        
        return response_text, audio_path
    except Exception as e:
        print(f"Error: {e}")
        return str(e), None

def transcribe_and_chat(audio):
    """Transcribe the recorded audio, get the assistant's reply, and return (text, audio_path)."""
    text = whisper_speech_to_text(audio)
    if text is None:
        return "Sorry, I couldn't understand the audio.", None
    
    response, audio_path = asyncio.run(chat_with_ai(text))
    return response, audio_path

def create_demo():
    """Build the Gradio Blocks interface for the voice assistant."""
    with gr.Blocks() as demo:
        gr.Markdown("# AI Voice Assistant")

        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(type="filepath", label="Press 'Record' to Speak")

            with gr.Column(scale=1):
                chat_output = gr.Textbox(label="AI Response")
                audio_output = gr.Audio(label="AI Voice Response", autoplay=True)

        def process_audio(audio):
            if audio is None:  # .change also fires when the input is cleared below; skip those events
                return gr.update(), gr.update(), gr.update()
            response, audio_path = transcribe_and_chat(audio)
            return response, audio_path, None  # Return None to clear the audio input

        # Client-side JS: auto-submit when recording stops and auto-play the assistant's reply
        demo.load(None, js="""
            function() {
                document.querySelector("audio").addEventListener("stop", function() {
                    setTimeout(function() {
                        document.querySelector('button[title="Submit"]').click();
                    }, 500);
                });
                
                function playAssistantAudio() {
                    var audioElements = document.querySelectorAll('audio');
                    if (audioElements.length > 1) {
                        var assistantAudio = audioElements[1];
                        if (assistantAudio) {
                            assistantAudio.play();
                        }
                    }
                }

                document.addEventListener('gradioAudioLoaded', function(event) {
                    playAssistantAudio();
                });

                document.addEventListener('gradioUpdated', function(event) {
                    setTimeout(playAssistantAudio, 100);
                });
            }
        """)

        audio_input.change(process_audio, inputs=[audio_input], outputs=[chat_output, audio_output, audio_input])

    return demo

# Launch the Gradio app
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)
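
# Launch sketch (the file name and token value below are placeholders for this example):
#
#   export HF_TOKEN=hf_xxxxxxxx   # required; the script raises ValueError without it
#   python app.py                 # serves the UI on http://localhost:7860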