import os
import uuid

import torch
import whisper
import soundfile as sf
import streamlit as st
from groq import Groq
from dotenv import load_dotenv
from tempfile import NamedTemporaryFile

# Load environment variables
load_dotenv()
API_KEY = os.getenv("GROQ_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
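# Expected .env entries (names taken from the os.getenv calls above; values are placeholders):
#   GROQ_API_KEY=<your Groq API key>
#   HF_TOKEN=<your Hugging Face token>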
# By using XTTS you agree to the CPML license
os.environ["COQUI_TOS_AGREED"] = "1"

# Import TTS components
from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager
# Download and configure the XTTS v2 model
print("Downloading Coqui XTTS v2 if not already downloaded")
model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
ModelManager().download_model(model_name)
model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
print("XTTS downloaded")

config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path=os.path.join(model_path, "model.pth"),
    vocab_path=os.path.join(model_path, "vocab.json"),
    eval=True,
    use_deepspeed=True,  # requires the deepspeed package; set to False if it is not installed
)
if torch.cuda.is_available():
    model.cuda()

supported_languages = config.languages
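# Note: config.languages is the list of language codes this XTTS v2 checkpoint supports
# (typically codes such as "en", "es", "fr", "de", "it", and "pt", among others).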
# LLM response function
def get_llm_response(api_key, user_input):
    if not api_key:
        return "API key not found. Please set the GROQ_API_KEY environment variable."

    client = Groq(api_key=api_key)
    prompt = (
        "IMPORTANT: You are an AI assistant that MUST provide responses in 25 words or less.\n"
        "CRITICAL RULES:\n"
        "1. NEVER exceed 25 words unless absolutely necessary.\n"
        "2. Always give a complete sentence with full context.\n"
        "3. Answer directly and precisely.\n"
        "4. Use clear, simple language.\n"
        "5. Maintain a polite, professional tone.\n"
        "6. NO lists, bullet points, or multiple paragraphs.\n"
        "7. NEVER apologize for brevity - embrace it.\n"
        "Your response will be converted to speech. Maximum 25 words."
    )
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": user_input},
            ],
            model="llama3-8b-8192",
            temperature=0.5,
            top_p=1,
            stream=False,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        return f"Error with LLM: {str(e)}"
# Transcribe audio with Whisper
def transcribe_audio(audio_path, model_size="base"):
    try:
        whisper_model = whisper.load_model(model_size)
        result = whisper_model.transcribe(audio_path)
        return result["text"]
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
# Generate speech using the configured XTTS model
def generate_speech(text, output_file, speaker_wav, language="en"):
    if not os.path.exists(speaker_wav):
        raise FileNotFoundError("Reference audio file not found. Please upload a valid audio file.")

    if language not in supported_languages:
        st.warning(f"Language {language} is not supported. Defaulting to English.")
        language = "en"

    try:
        # Extract speaker conditioning latents from the reference audio
        gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
            audio_path=speaker_wav,
            gpt_cond_len=30,
            gpt_cond_chunk_len=4,
            max_ref_length=60,
        )
        out = model.inference(
            text,
            language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=0.75,
        )
        # Convert the generated waveform to a NumPy array and save it to disk
        wav = torch.tensor(out["wav"]).squeeze().cpu().numpy()
        sf.write(output_file, wav, 24000, subtype="PCM_24")
        return True, "Speech generated successfully"
    except Exception as e:
        return False, f"Error generating speech: {str(e)}"
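
# Example usage (sketch; "reference.wav" is a placeholder path to an uploaded voice sample):
#   ok, msg = generate_speech("Hello there!", "cloned_reply.wav", "reference.wav", language="en")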
# Streamlit app
def main():
    st.set_page_config(page_title="Vocal AI", layout="wide")
    st.title("Vocal AI - Voice Cloning Assistant")
    st.write("Clone your voice and interact with an AI assistant that responds in your voice!")

    st.sidebar.title("Settings")

    # Language selection
    language = st.sidebar.selectbox(
        "Output Language",
        supported_languages,
        index=supported_languages.index("en") if "en" in supported_languages else 0,
    )

    # TOS agreement
    agree_tos = st.sidebar.checkbox("I agree to the Coqui Public Model License (CPML)", value=False)

    col1, col2 = st.columns(2)

    with col1:
        st.header("Step 1: Provide Reference Voice")
        reference_audio = st.file_uploader("Upload Reference Audio", type=["wav", "mp3", "ogg"])
        ref_audio_path = None
        if reference_audio:
            with NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_audio:
                temp_ref_audio.write(reference_audio.read())
                ref_audio_path = temp_ref_audio.name
            st.audio(ref_audio_path)

    with col2:
        st.header("Step 2: Ask Something")
        # User input (text or audio)
        input_type = st.radio("Choose Input Type", ("Text", "Upload Audio"))
        user_input = None
        if input_type == "Text":
            user_input = st.text_area("Enter your question or prompt here")
        else:
            user_audio = st.file_uploader("Upload your question as audio", type=["wav", "mp3", "ogg"])
            if user_audio:
                with NamedTemporaryFile(delete=False, suffix=".wav") as temp_user_audio:
                    temp_user_audio.write(user_audio.read())
                st.audio(temp_user_audio.name)
                user_input = transcribe_audio(temp_user_audio.name)
                st.write(f"Transcribed: {user_input}")

    # Process input and generate the response
    if st.button("Generate AI Response in My Voice"):
        if not agree_tos:
            st.error("Please agree to the Coqui Public Model License to continue.")
            return
        if not ref_audio_path:
            st.error("Please upload reference audio.")
            return
        if not user_input:
            st.error("Please enter text or upload an audio question.")
            return

        with st.spinner("Processing..."):
            # Get the AI response
            llm_response = get_llm_response(API_KEY, user_input)
            st.subheader("AI Response:")
            st.write(llm_response)

            # Generate speech in the cloned voice
            output_audio_path = f"output_speech_{uuid.uuid4()}.wav"
            success, message = generate_speech(
                llm_response,
                output_audio_path,
                ref_audio_path,
                language,
            )
            if success:
                st.subheader("Listen to the response in your voice:")
                st.audio(output_audio_path, format="audio/wav")
            else:
                st.error(message)


if __name__ == "__main__":
    main()
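
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py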