sudoping01's picture
Update app.py
e849c49 verified
raw
history blame
11.5 kB
import os
# Export the runtime switches before `import torch` below, so they are already
# present in the process environment when torch (and anything it loads) is
# imported. NOTE(review): presumably these are read at import/startup time —
# confirm per library. os.environ only stores str values, hence "1"/"false".
os.environ.update(
    {
        "TORCHDYNAMO_DISABLE": "1",
        "TORCH_COMPILE_DISABLE": "1",
        "PYTORCH_DISABLE_CUDNN_BENCHMARK": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)
import torch
import gradio as gr
import numpy as np
import spaces
import logging
from huggingface_hub import login
import threading
import time
# In addition to the environment variables exported before `import torch`,
# turn on TorchDynamo's `disable` and `suppress_errors` flags through its
# runtime config object.
for _dynamo_option in ("disable", "suppress_errors"):
    setattr(torch._dynamo.config, _dynamo_option, True)
del _dynamo_option

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Log in to the Hugging Face Hub only when HF_TOKEN is set to a non-empty
# value; otherwise no login is attempted.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
class ModelSingleton:
    """Process-wide holder for the Bambara TTS model and its speaker table.

    Every ``ModelSingleton()`` call returns the same instance. The model is
    loaded on the first :meth:`initialize` / :meth:`get_model` call and the
    loaded objects are returned on later calls.
    """

    _instance = None
    _lock = threading.Lock()  # serializes creation of the single instance

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created the
                # instance while this one was waiting.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance.initialized = False
                    instance.tts_model = None
                    instance.speakers_dict = None
                    instance.init_lock = threading.RLock()  # guards initialize()
                    # Publish only after setup, so the unlocked check above can
                    # never hand out an instance that lacks these attributes.
                    cls._instance = instance
        return cls._instance

    @spaces.GPU()
    def initialize(self):
        """Load the TTS model and speaker table once and return them.

        Returns:
            tuple: ``(tts_model, speakers_dict)``.

        Raises:
            Exception: any error raised while importing or constructing the
                model is logged with its traceback and re-raised. ``initialized``
                stays False in that case, so a later call retries.
        """
        # NOTE(review): this method runs under ``@spaces.GPU()`` and mutates
        # ``self``. If that decorator executes the call outside the main process
        # (as on ZeroGPU hardware), these assignments may not persist — confirm.
        if self.initialized:
            logger.info("Model already initialized, skipping...")
            return self.tts_model, self.speakers_dict
        with self.init_lock:
            # Re-check under the lock: another thread may have finished loading.
            if self.initialized:
                logger.info("Model already initialized (double-check), skipping...")
                return self.tts_model, self.speakers_dict
            logger.info("Initializing Bambara TTS model...")
            start_time = time.perf_counter()
            try:
                # Imported lazily so the heavy package loads only when needed.
                from maliba_ai.tts.inference import BambaraTTSInference
                from maliba_ai.config.speakers import Adame, Moussa, Bourama, Modibo, Seydou

                self.tts_model = BambaraTTSInference()
                self.speakers_dict = {
                    # "Adame" matches the imported speaker object and the UI's
                    # SPEAKER_NAMES; "Adama" is kept as an alias for the same voice.
                    "Adame": Adame,
                    "Adama": Adame,
                    "Moussa": Moussa,
                    "Bourama": Bourama,
                    "Modibo": Modibo,
                    "Seydou": Seydou,
                }
                self.initialized = True
                elapsed = time.perf_counter() - start_time
                logger.info(f"Model initialized successfully in {elapsed:.2f} seconds!")
            except Exception as e:
                logger.exception(f"Failed to initialize model: {e}")
                raise
        return self.tts_model, self.speakers_dict

    def get_model(self):
        """Return ``(tts_model, speakers_dict)``, initializing on first use."""
        if not self.initialized:
            return self.initialize()
        return self.tts_model, self.speakers_dict


# Global singleton instance, used by generate_speech() and preload_model().
model_singleton = ModelSingleton()
def validate_inputs(text, temperature, top_k, top_p, max_tokens):
    """Check the input text and the advanced generation parameters.

    Args:
        text: Bambara text to synthesize.
        temperature: sampling temperature, accepted in [0.001, 2.0].
        top_k: top-k cutoff, accepted in [1, 100].
        top_p: nucleus-sampling threshold, accepted in [0.1, 1.0].
        max_tokens: maximum number of new audio tokens, accepted in
            [256, 4096] (the range offered by the "Max Length" slider).

    Returns:
        tuple[bool, str]: ``(True, "")`` when every value is acceptable,
        otherwise ``(False, message)`` describing the first problem found.
    """
    if not text or not text.strip():
        return False, "Please enter some Bambara text."
    # A cleared numeric field may arrive as None; report it instead of letting
    # the range comparisons below raise TypeError.
    if any(value is None for value in (temperature, top_k, top_p, max_tokens)):
        return False, "Please set all advanced parameters."
    if not (0.001 <= temperature <= 2.0):
        return False, "Temperature must be between 0.001 and 2.0"
    if not (1 <= top_k <= 100):
        return False, "Top-K must be between 1 and 100"
    if not (0.1 <= top_p <= 1.0):
        return False, "Top-P must be between 0.1 and 1.0"
    if not (256 <= max_tokens <= 4096):
        return False, "Max Length must be between 256 and 4096"
    return True, ""
@spaces.GPU()
def generate_speech(text, speaker_name, use_advanced, temperature, top_k, top_p, max_tokens):
    """Synthesize Bambara speech for the Gradio UI.

    Args:
        text: Bambara text to speak; surrounding whitespace is stripped.
        speaker_name: name looked up in the speaker table returned by
            ``model_singleton.get_model()``.
        use_advanced: when truthy, the sampling parameters below are validated
            and passed to the model; otherwise only text and speaker are passed.
        temperature, top_k, top_p, max_tokens: sampling parameters, used only
            when ``use_advanced`` is truthy.

    Returns:
        tuple: ``((sample_rate, waveform), status)`` on success, or
        ``(None, status)`` on failure. Errors are reported through the status
        string rather than raised.
    """
    # Guard against None as well as blank text (None would crash on .strip()).
    if not text or not text.strip():
        return None, "Please enter some Bambara text."
    try:
        # Validate before touching the model so bad settings fail fast.
        if use_advanced:
            is_valid, error_msg = validate_inputs(text, temperature, top_k, top_p, max_tokens)
            if not is_valid:
                return None, f"❌ {error_msg}"

        tts, speakers = model_singleton.get_model()
        speaker = speakers.get(speaker_name)
        if speaker is None:
            return None, f"❌ Unknown speaker: {speaker_name}"

        clean_text = text.strip()
        if use_advanced:
            waveform = tts.generate_speech(
                text=clean_text,
                speaker_id=speaker,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                max_new_audio_tokens=int(max_tokens),
            )
        else:
            waveform = tts.generate_speech(text=clean_text, speaker_id=speaker)

        if waveform is None or waveform.size == 0:
            return None, "Failed to generate audio. Please try again."
        sample_rate = 16000  # matches the 16kHz stated in the UI's About section
        return (sample_rate, waveform), "✅ Audio generated successfully"
    except Exception as e:
        logger.exception(f"Speech generation failed: {e}")
        return None, f"❌ Error: {e}"
# Preload the model on startup (optional: remove the preload_model() call in
# main() if you prefer lazy loading on the first request).
def preload_model():
    """Load the model eagerly; log, but do not raise, on failure.

    A failure here is deliberately non-fatal: ``initialized`` stays False, so
    ``model_singleton.get_model()`` retries initialization on the next request.
    """
    try:
        logger.info("Preloading model...")
        model_singleton.initialize()
        logger.info("Model preloaded successfully!")
    except Exception as e:
        # Keep the traceback in the log; startup continues without the model.
        logger.exception(f"Failed to preload model: {e}")
# Voices offered in the speaker dropdown. generate_speech() looks the selected
# name up in the speaker table provided by the model singleton.
SPEAKER_NAMES = ["Adame", "Moussa", "Bourama", "Modibo", "Seydou"]

# (text, speaker) pairs shown as one-click examples. Each speaker must be one
# of SPEAKER_NAMES so that loading an example selects a valid dropdown choice.
examples = [
    ["Aw ni ce", "Adame"],
    ["Mali bɛna diya kɔsɛbɛ, ka a da a kan baara bɛ ka kɛ.", "Moussa"],
    ["Ne bɛ se ka sɛbɛnni yɛlɛma ka kɛ kuma ye", "Bourama"],
    ["I ka kɛnɛ wa?", "Modibo"],
    ["Lakɔli karamɔgɔw tun tɛ ka se ka sɛbɛnni kɛ ka ɲɛ walanba kan wa denmisɛnw tun tɛ ka se ka o sɛbɛnni ninnu ye, kuma tɛ ka u kalan. Denmisɛnw kɛra kunfinw ye.", "Adame"],
    ["sigikafɔ kɔnɔ jamanaw ni ɲɔgɔn cɛ, olu ye a haminankow ye, wa o ko ninnu ka kan ka kɛ sariya ani tilennenya kɔnɔ.", "Seydou"],
    ["Aw ni ce. Ne tɔgɔ ye Adama. Awɔ, ne ye maliden de ye. Aw Sanbɛ Sanbɛ. San min tɛ ɲinan ye, an bɛɛ ka jɛ ka o seli ɲɔgɔn fɛ, hɛɛrɛ ni lafiya la. Ala ka Mali suma. Ala ka Mali yiriwa. Ala ka Mali taa ɲɛ. Ala ka an ka seliw caya. Ala ka yafa an bɛɛ ma.", "Moussa"],
    ["An dɔlakelen bɛ masike bilenman don ka tɔw gɛn.", "Bourama"],
    ["Aw ni ce. Seidu bɛ aw fo wa aw ka yafa a ma, ka da a kan tuma dɔw la kow ka can.", "Modibo"],
]
def build_interface():
    """Build the Gradio interface for Bambara TTS.

    Returns:
        gr.Blocks: the assembled app (not launched). The Generate button and
        pressing Enter in the text box both call generate_speech() with the
        same inputs and outputs.
    """
    # Advanced-parameter defaults, shared by the sliders and by the example
    # loader, which resets the sliders to them.
    default_temperature = 0.8
    default_top_k = 50
    default_top_p = 0.9
    default_max_tokens = 2048

    def load_example(text, speaker):
        # Fill in the example and switch back to default generation settings.
        return (
            text,
            speaker,
            False,
            default_temperature,
            default_top_k,
            default_top_p,
            default_max_tokens,
        )

    with gr.Blocks(title="Bambara TTS - EXPERIMENTAL") as demo:
        gr.Markdown("""
# 🎤 Bambara Text-to-Speech ⚠️ EXPERIMENTAL
**Powered by MALIBA-AI**
Convert Bambara text to speech. This model is currently experimental.
**Bambara** is spoken by millions of people in Mali and West Africa.
""")
        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="📝 Bambara Text",
                    placeholder="Type your Bambara text here...",
                    lines=3,
                    max_lines=10,
                    value="I ni ce",
                )
                speaker_dropdown = gr.Dropdown(
                    choices=SPEAKER_NAMES,
                    value="Adame",
                    label="🗣️ Speaker Voice",
                )
                generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
            with gr.Column(scale=1):
                use_advanced = gr.Checkbox(
                    label="⚙️ Use Advanced Settings",
                    value=False,
                    info="Enable to customize generation parameters",
                )
                with gr.Group(visible=False) as advanced_group:
                    gr.Markdown("**Advanced Parameters:**")
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=default_temperature,
                        step=0.1,
                        label="Temperature",
                        info="Higher = more varied",
                    )
                    # Slider steps count from `minimum`: with minimum=1 a step of
                    # 5 would allow only 1, 6, ..., 96, excluding both the default
                    # (50) and the maximum (100). Step 1 keeps the full range.
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=default_top_k,
                        step=1,
                        label="Top-K",
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=default_top_p,
                        step=0.05,
                        label="Top-P",
                    )
                    max_tokens = gr.Slider(
                        minimum=256,
                        maximum=4096,
                        value=default_max_tokens,
                        step=256,
                        label="Max Length",
                    )

        gr.Markdown("### 🔊 Generated Audio")
        audio_output = gr.Audio(
            label="Generated Speech",
            type="numpy",
            interactive=False,
        )
        status_output = gr.Textbox(
            label="Status",
            interactive=False,
            show_label=False,
            container=False,
        )

        # Shared by both generation triggers and by the example buttons, which
        # reset every generation input.
        generation_inputs = [
            text_input,
            speaker_dropdown,
            use_advanced,
            temperature,
            top_k,
            top_p,
            max_tokens,
        ]
        generation_outputs = [audio_output, status_output]

        with gr.Accordion("Try These Examples", open=True):
            gr.Markdown("**Click any example below:**")
            for example_text, example_speaker in examples:
                example_btn = gr.Button(
                    f"{example_text[:30]}{'...' if len(example_text) > 30 else ''}",
                    size="sm",
                )
                # Bind the loop values as lambda defaults so each button keeps
                # its own example (a plain closure would see only the last one).
                example_btn.click(
                    fn=lambda t=example_text, s=example_speaker: load_example(t, s),
                    outputs=generation_inputs,
                )

        with gr.Accordion("About", open=False):
            gr.Markdown("""
**⚠️ This is an experimental Bambara TTS model.**
- **Languages**: Bambara (bm)
- **Speakers**: 5 different voice options
- **Sample Rate**: 16kHz
**Status**: Model loads once and reuses for all requests
""")

        def toggle_advanced(use_adv):
            # Show the advanced-parameter group only while the checkbox is ticked.
            return gr.Group(visible=use_adv)

        use_advanced.change(
            fn=toggle_advanced,
            inputs=[use_advanced],
            outputs=[advanced_group],
        )
        generate_btn.click(
            fn=generate_speech,
            inputs=generation_inputs,
            outputs=generation_outputs,
            show_progress=True,
        )
        text_input.submit(
            fn=generate_speech,
            inputs=generation_inputs,
            outputs=generation_outputs,
            show_progress=True,
        )
    return demo
def main():
    """Preload the model, build the Gradio UI and serve it on 0.0.0.0:7860.

    When run as a script, ``launch()`` blocks while the server is running, so
    the final log line is emitted only after the server has stopped.
    """
    logger.info("Starting Bambara TTS Gradio interface.")
    preload_model()
    interface = build_interface()
    logger.info("Launching Gradio interface on 0.0.0.0:7860.")
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
    logger.info("Gradio interface stopped.")


if __name__ == "__main__":
    main()