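"""Vocal AI: a Streamlit voice-cloning assistant.

Flow: record or upload a reference voice, take a text or spoken question,
answer it with a Groq-hosted LLM, transcribe speech with Whisper, and read
the reply aloud in the cloned voice via Coqui XTTS v2.

Assumed dependencies (unpinned, best-guess package names): streamlit,
streamlit-webrtc, openai-whisper, TTS (Coqui), groq, av.
"""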
import os
import wave
from tempfile import NamedTemporaryFile

import av
import streamlit as st
import whisper
from groq import Groq
from TTS.api import TTS
from streamlit_webrtc import webrtc_streamer, WebRtcMode
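
# Model loading is expensive, so cache the Whisper and XTTS models across
# Streamlit reruns and create each once per process. st.cache_resource
# assumes a reasonably recent Streamlit release (>= 1.18).
@st.cache_resource
def load_whisper_model(model_size="base"):
    return whisper.load_model(model_size)

@st.cache_resource
def load_tts_model(use_gpu=True):
    # gpu= is the long-standing Coqui TTS flag; newer releases prefer
    # TTS(...).to("cuda") instead.
    return TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=use_gpu)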

# LLM Response Function
def get_llm_response(api_key, user_input):
    """Ask the Groq-hosted LLM for a short, speech-friendly answer."""
    client = Groq(api_key=api_key)
    prompt = (
        "IMPORTANT: You are an AI assistant that MUST respond in 25 words or fewer.\n"
        "CRITICAL RULES:\n"
        "1. NEVER exceed 25 words.\n"
        "2. Always answer with a complete sentence that carries full context.\n"
        "3. Answer directly and precisely.\n"
        "4. Use clear, simple language.\n"
        "5. Maintain a polite, professional tone.\n"
        "6. NO lists, bullet points, or multiple paragraphs.\n"
        "7. NEVER apologize for brevity - embrace it.\n"
        "Your response will be converted to speech. Maximum 25 words."
    )

    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": user_input}
        ],
        model="llama3-8b-8192",
        temperature=0.5,
        top_p=1,
        stream=False,
    )
    return chat_completion.choices[0].message.content

# Transcribe Audio
def transcribe_audio(audio_path, model_size="base"):
    """Transcribe an audio file to text with Whisper."""
    model = load_whisper_model(model_size)
    result = model.transcribe(audio_path)
    return result["text"]

# Generate Speech
def generate_speech(text, output_file, speaker_wav, language="en", use_gpu=True):
    """Synthesize `text` into `output_file`, cloning the voice in `speaker_wav`."""
    if not os.path.exists(speaker_wav):
        raise FileNotFoundError("Reference audio file not found. Please upload or record a valid audio clip.")

    tts = load_tts_model(use_gpu)
    tts.tts_to_file(
        text=text,
        file_path=output_file,
        speaker_wav=speaker_wav,
        language=language,
    )

# Audio Frame Processing
class AudioProcessor:
    """Collects PCM from incoming WebRTC audio frames and writes a valid WAV file."""

    def __init__(self):
        self.audio_frames = []
        self.sample_rate = 48000  # overwritten by the first frame received
        self.channels = 2

    def recv(self, frame: av.AudioFrame) -> av.AudioFrame:
        # Remember the stream parameters so the WAV header below is correct.
        self.sample_rate = frame.sample_rate
        self.channels = len(frame.layout.channels)
        self.audio_frames.append(frame.to_ndarray().tobytes())
        return frame

    def save_audio(self, file_path):
        # Bare PCM bytes are not a readable .wav file, so wrap the samples in
        # a RIFF/WAV header (16-bit PCM, which WebRTC delivers by default).
        with wave.open(file_path, "wb") as wf:
            wf.setnchannels(self.channels)
            wf.setsampwidth(2)  # 2 bytes = 16-bit samples
            wf.setframerate(self.sample_rate)
            wf.writeframes(b"".join(self.audio_frames))
        return file_path

# Streamlit App
def main():
    st.set_page_config(page_title="Vocal AI", layout="wide")
    st.sidebar.title("Vocal-AI Settings")

    # User option for reference audio (Record or Upload)
    ref_audio_choice = st.sidebar.radio("Reference Audio", ("Upload", "Record"))

    ref_audio_path = None
    reference_audio_processor = None

    if ref_audio_choice == "Upload":
        reference_audio = st.sidebar.file_uploader("Upload Reference Audio", type=["wav", "mp3", "ogg"])
        if reference_audio:
            with NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_audio:
                temp_ref_audio.write(reference_audio.read())
                ref_audio_path = temp_ref_audio.name
    else:
        st.sidebar.write("Record your reference audio:")
        # Keep the processor in session_state so captured frames survive
        # Streamlit's script reruns; a fresh instance on every rerun would
        # silently drop everything recorded so far.
        if "ref_audio_processor" not in st.session_state:
            st.session_state.ref_audio_processor = AudioProcessor()
        reference_audio_processor = st.session_state.ref_audio_processor
        webrtc_streamer(
            key="ref_audio",
            mode=WebRtcMode.SENDRECV,
            rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
            media_stream_constraints={"audio": True, "video": False},
            audio_receiver_size=1024,
            audio_processor_factory=lambda: reference_audio_processor,
        )

    st.title("Welcome to Vocal AI")
    st.write("### How to Use")
    st.write("1. Upload or record a reference audio file.")
    st.write("2. Choose between text or audio input.")
    st.write("3. If audio input is selected, record and submit your audio.")
    st.write("4. Click 'Generate Speech' to hear the AI response in your cloned voice.")

    # User Input (Text or Audio)
    input_type = st.radio("Choose Input Type", ("Text", "Audio"))
    user_input = None
    user_audio_processor = None

    if input_type == "Text":
        user_input = st.text_area("Enter your text here")
    else:
        st.write("Record your voice:")
        if "user_audio_processor" not in st.session_state:
            st.session_state.user_audio_processor = AudioProcessor()
        user_audio_processor = st.session_state.user_audio_processor
        webrtc_streamer(
            key="user_audio",
            mode=WebRtcMode.SENDRECV,
            rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
            media_stream_constraints={"audio": True, "video": False},
            audio_receiver_size=1024,
            audio_processor_factory=lambda: user_audio_processor,
        )

    if st.button("Generate Speech"):
        # Handle Reference Audio
        if reference_audio_processor and reference_audio_processor.audio_frames:
            with NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_audio:
                reference_audio_processor.save_audio(temp_ref_audio.name)
                ref_audio_path = temp_ref_audio.name

        if not ref_audio_path:
            st.error("Please upload or record reference audio.")
            return

        # Handle User Input
        if input_type == "Audio":
            if user_audio_processor and user_audio_processor.audio_frames:
                with NamedTemporaryFile(delete=False, suffix=".wav") as temp_user_audio:
                    user_audio_processor.save_audio(temp_user_audio.name)
                    user_input = transcribe_audio(temp_user_audio.name)
                    os.unlink(temp_user_audio.name)

        if not user_input:
            st.error("Please enter text or record audio.")
            return

        # Get AI Response
        api_key = st.secrets.get("GROQ_API_KEY")
        if not api_key:
            st.error("GROQ_API_KEY is missing from Streamlit secrets.")
            return
        response_text = get_llm_response(api_key, user_input)

        # Generate Speech
        output_audio_path = "output_speech.wav"
        try:
            generate_speech(response_text, output_audio_path, ref_audio_path)
            st.audio(output_audio_path, format="audio/wav")
        except FileNotFoundError as e:
            st.error(str(e))
        finally:
            # Remove the temporary reference clip even if synthesis fails.
            if ref_audio_path and os.path.exists(ref_audio_path):
                os.unlink(ref_audio_path)

if __name__ == "__main__":
    main()