Last commit not found
import gradio as gr | |
import torch | |
from TTS.api import TTS | |
import os | |
import spaces | |
import tempfile | |
from pymongo import MongoClient | |
from dotenv import load_dotenv | |
from huggingface_hub import hf_hub_download | |
from transformers import AutoTokenizer | |
# Pull configuration (MongoDB connection string, Hugging Face token) from a
# local .env file into the process environment, then read it back.
load_dotenv()
mongodb_uri = os.getenv('MONGODB_URI')
hf_token = os.getenv('HF_TOKEN')

# The voice catalogue lives in the `voices` collection of the `mitra` database.
client = MongoClient(mongodb_uri)
db = client['mitra']
voices_collection = db['voices']

# Pre-accept the Coqui model terms via the environment before the model is
# loaded (presumably so TTS does not ask for agreement interactively).
os.environ["COQUI_TOS_AGREED"] = "1"

# Run inference on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_tts_model():
    """Load the multilingual XTTS v2 model and move it onto ``device``."""
    model_id = "tts_models/multilingual/multi-dataset/xtts_v2"
    model = TTS(model_id)
    return model.to(device)


# Loaded once at import time and shared by the synthesis handlers below.
tts = load_tts_model()
def get_celebrity_voices():
    """
    Build the voice catalogue from MongoDB.

    Every document returned by ``voices_collection.find()`` is read for its
    ``voices`` list; each entry's ``name`` is mapped to the relative path of
    its reference clip, ``voices/<name>.mp3``, in the Space repository.
    When the same name appears more than once, the last occurrence wins.
    """
    return {
        entry['name']: f"voices/{entry['name']}.mp3"
        for category in voices_collection.find()
        for entry in category['voices']
    }


# Snapshot taken once at start-up; the UI dropdown and downloads use it.
celebrity_voices = get_celebrity_voices()
def check_voice_files():
    """
    Verify that every catalogued voice clip can be fetched from the Space repo.

    Each path in ``celebrity_voices`` is downloaded with ``hf_hub_download``;
    any failure (of any kind) marks that voice as missing.

    Returns:
        A Markdown string listing the missing ``name: path`` pairs, or a
        confirmation that all files are present.
    """
    def _is_available(path):
        # Any exception (not-found, auth, network) counts as unavailable.
        try:
            hf_hub_download(repo_id="nikkmitra/clone", filename=path, repo_type="space", token=hf_token)
        except Exception:
            return False
        return True

    missing = [
        f"{voice}: {path}"
        for voice, path in celebrity_voices.items()
        if not _is_available(path)
    ]
    if not missing:
        return "**All voice files are present.** 🎉"
    return "**Missing Voice Files:**\n" + "\n".join(missing)
# Split long input text into newline-separated chunks of bounded token count.
def split_text_into_chunks(text, max_tokens=100, language="en", *, tokenizer=None):
    """
    Split *text* into chunks of at most *max_tokens* tokens, one chunk per line.

    Args:
        text: The text to split.
        max_tokens: Maximum number of tokens per chunk; must be >= 1.
        language: Currently unused; kept for backward compatibility with
            existing callers.
        tokenizer: Optional callable mapping a string to a list of token
            strings (e.g. a language-specific tokenizer). Defaults to
            whitespace splitting (``str.split``).

    Returns:
        The chunks joined with ``"\\n"``; tokens inside a chunk are joined
        with single spaces. Empty or whitespace-only input yields ``""``.

    Raises:
        ValueError: If *max_tokens* is less than 1.
    """
    if max_tokens < 1:
        raise ValueError("max_tokens must be a positive integer")
    # The original body referenced an undefined `tokens`; tokenize explicitly.
    tokens = tokenizer(text) if tokenizer is not None else text.split()
    chunks = [
        " ".join(tokens[start:start + max_tokens])
        for start in range(0, len(tokens), max_tokens)
    ]
    return "\n".join(chunks)
def tts_generate(text, voice, language):
    """
    Synthesize *text* with one of the catalogued celebrity voices.

    Args:
        text: Text to speak.
        voice: Voice name; looked up as a key of ``celebrity_voices``.
        language: Language code passed through to the TTS model (e.g. "en").

    Returns:
        Path of a temporary ``.wav`` file containing the generated speech,
        or a human-readable error string if the voice clip could not be
        fetched or synthesis failed. On failure no temp file is left behind.
    """
    # NOTE(review): this feeds a gr.Audio output, which expects a file path;
    # consider raising gr.Error instead of returning error text.
    try:
        voice_file = hf_hub_download(
            repo_id="nikkmitra/clone",
            filename=celebrity_voices[voice],
            repo_type="space",
            token=hf_token,
        )
    except Exception as e:
        return f"Error downloading voice file: {e}"

    # Reserve the output file only after the reference clip is available, so
    # a failed download cannot orphan a temp file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        temp_audio_path = temp_audio.name
    try:
        tts.tts_to_file(
            text=text,
            speaker_wav=voice_file,
            language=language,
            file_path=temp_audio_path,
        )
    except Exception as e:
        # delete=False means nothing else removes the file; drop it here.
        try:
            os.remove(temp_audio_path)
        except OSError:
            pass
        if isinstance(e, AssertionError):
            return f"Error: {e}"
        return f"An unexpected error occurred: {e}"
    return temp_audio_path
def clone_voice(text, audio_file, language):
    """
    Synthesize *text* in the voice of a user-supplied reference recording.

    Args:
        text: Text to speak.
        audio_file: Filesystem path of the reference audio (the UI's Audio
            input uses type="filepath"); presumably None when nothing was
            uploaded.
        language: Language code passed through to the TTS model (e.g. "en").

    Returns:
        Path of a temporary ``.wav`` file containing the generated speech,
        or a human-readable error string. On failure no temp file is left
        behind.
    """
    # NOTE(review): this feeds a gr.Audio output, which expects a file path;
    # consider raising gr.Error instead of returning error text.
    print("cloning")
    if not audio_file:
        return "Error: please provide a voice reference audio file."

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        temp_audio_path = temp_audio.name
    try:
        tts.tts_to_file(
            text=text,
            speaker_wav=audio_file,
            language=language,
            file_path=temp_audio_path,
        )
    except Exception as e:
        # delete=False means nothing else removes the file; drop it here.
        try:
            os.remove(temp_audio_path)
        except OSError:
            pass
        if isinstance(e, AssertionError):
            return f"Error: {e}"
        return f"An unexpected error occurred: {e}"
    return temp_audio_path
# Language codes offered by both tabs' dropdowns.
_LANGUAGES = ("en", "es", "fr", "de", "it", "ar", "hi")

# Build the Gradio UI: a status banner, then one tab for the catalogued
# celebrity voices and one for cloning from an uploaded reference clip.
with gr.Blocks() as demo:
    gr.Markdown("# Advanced Voice Synthesis")

    # Availability report for the voice files, computed once at build time.
    voice_status = check_voice_files()
    gr.Markdown(voice_status)

    with gr.Tabs():
        with gr.TabItem("TTS"):
            with gr.Row():
                tts_text = gr.Textbox(label="Text to speak")
                tts_voice = gr.Dropdown(
                    choices=list(celebrity_voices.keys()),
                    label="Celebrity Voice",
                )
                tts_language = gr.Dropdown(
                    choices=list(_LANGUAGES), label="Language", value="en"
                )
            tts_generate_btn = gr.Button("Generate")
            tts_output = gr.Audio(label="Generated Audio")
            tts_generate_btn.click(
                fn=tts_generate,
                inputs=[tts_text, tts_voice, tts_language],
                outputs=tts_output,
            )

        with gr.TabItem("Clone Voice"):
            with gr.Row():
                clone_text = gr.Textbox(label="Text to speak")
                clone_audio = gr.Audio(
                    label="Voice reference audio file", type="filepath"
                )
                clone_language = gr.Dropdown(
                    choices=list(_LANGUAGES), label="Language", value="en"
                )
            clone_generate_btn = gr.Button("Generate")
            clone_output = gr.Audio(label="Generated Audio")
            clone_generate_btn.click(
                fn=clone_voice,
                inputs=[clone_text, clone_audio, clone_language],
                outputs=clone_output,
            )
# Launch the interface; this call returns once the Gradio server shuts down.
demo.launch()

# Clean up the temporary audio files produced by the synthesis handlers.
# tempfile.NamedTemporaryFile() creates them in tempfile.gettempdir() (with
# the default "tmp" prefix), not in the current working directory, so scan
# that directory rather than os.listdir().
# NOTE(review): this may also remove other processes' tmp*.wav files if the
# temp dir is shared; a dedicated subdirectory would be safer.
_tmp_dir = tempfile.gettempdir()
for _name in os.listdir(_tmp_dir):
    if _name.startswith("tmp") and _name.endswith(".wav"):
        try:
            os.remove(os.path.join(_tmp_dir, _name))
        except OSError:
            # Best effort: the file may already be gone or not be removable.
            pass