# app.py — Bambara TTS Gradio Space
# (Hugging Face page chrome removed from this paste: "Update app.py",
#  commit fa84412 verified, raw / history / blame, 9.96 kB)
import gradio as gr
import numpy as np
import os
import spaces
import logging
from huggingface_hub import login
# Set up module-level logging for the app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Authenticate with the Hugging Face Hub if a token is provided via env var.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Global singletons for the model and speakers: loaded lazily inside a GPU
# context by initialize_model() and cached for the lifetime of the process.
tts_model = None
speakers_dict = None
model_initialized = False
@spaces.GPU()
def initialize_model():
    """Lazily initialize the TTS model and speaker registry (GPU context).

    Imports and constructs the model on first call only; subsequent calls
    return the cached globals.

    Returns:
        tuple: (tts_model, speakers_dict) where speakers_dict maps speaker
        display names to speaker config objects.

    Raises:
        Exception: re-raises any failure during model construction after
        logging it.
    """
    global tts_model, speakers_dict, model_initialized
    if not model_initialized:
        logger.info("Initializing Bambara TTS model...")
        try:
            # Import inside GPU context to avoid CUDA initialization errors
            from maliba_ai.tts.inference import BambaraTTSInference
            from maliba_ai.config.speakers import Adame, Moussa, Bourama, Modibo, Seydou

            # Initialize model
            tts_model = BambaraTTSInference()

            # Initialize speakers
            speakers_dict = {
                "Adame": Adame,
                "Moussa": Moussa,
                "Bourama": Bourama,
                "Modibo": Modibo,
                "Seydou": Seydou
            }
            model_initialized = True
            logger.info("Model initialized successfully!")
        except Exception as e:
            logger.error(f"Failed to initialize model: {e}")
            # Bare `raise` re-raises the active exception with its original
            # traceback intact (idiomatic; `raise e` was redundant).
            raise
    return tts_model, speakers_dict
def validate_inputs(text, temperature, top_k, top_p, max_tokens):
    """Validate user-supplied text and generation parameters.

    Ranges mirror the UI sliders in build_interface().

    Args:
        text: Bambara text to synthesize; must be non-empty after stripping.
        temperature: sampling temperature, (0.001, 2.0].
        top_k: top-k sampling cutoff, [1, 100].
        top_p: nucleus sampling probability, [0.1, 1.0].
        max_tokens: maximum audio tokens to generate, [256, 4096].

    Returns:
        tuple: (is_valid, error_message) — error_message is "" when valid.
    """
    if not text or not text.strip():
        return False, "Please enter some Bambara text."
    if not (0.001 <= temperature <= 2.0):
        return False, "Temperature must be between 0.001 and 2.0"
    if not (1 <= top_k <= 100):
        return False, "Top-K must be between 1 and 100"
    if not (0.1 <= top_p <= 1.0):
        return False, "Top-P must be between 0.1 and 1.0"
    # Previously unchecked: max_tokens was accepted but never validated.
    # Bounds match the "Max Length" slider (256–4096) in the UI.
    if not (256 <= max_tokens <= 4096):
        return False, "Max tokens must be between 256 and 4096"
    return True, ""
@spaces.GPU()
def generate_speech(text, speaker_name, use_advanced, temperature, top_k, top_p, max_tokens):
    """Synthesize Bambara speech for the given text and speaker.

    Returns a Gradio-friendly pair: ((sample_rate, waveform) or None,
    status message string).
    """
    if not text.strip():
        return None, "Please enter some Bambara text."

    try:
        # Lazily obtain the cached model and speaker registry.
        tts, speakers = initialize_model()
        voice = speakers[speaker_name]
        cleaned_text = text.strip()

        if use_advanced:
            # Only advanced mode exposes tunable sampling parameters, so
            # only then do we range-check them.
            is_valid, error_msg = validate_inputs(text, temperature, top_k, top_p, max_tokens)
            if not is_valid:
                return None, f"❌ {error_msg}"
            waveform = tts.generate_speech(
                text=cleaned_text,
                speaker_id=voice,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                max_new_audio_tokens=int(max_tokens)
            )
        else:
            # Default mode: let the model use its built-in settings.
            waveform = tts.generate_speech(text=cleaned_text, speaker_id=voice)

        if waveform.size == 0:
            return None, "Failed to generate audio. Please try again."

        # Model output is 16 kHz mono audio.
        return (16000, waveform), f"✅ Audio generated successfully"
    except Exception as e:
        logger.error(f"Speech generation failed: {e}")
        return None, f"❌ Error: {str(e)}"
# Define speaker names for UI (must match the keys of speakers_dict).
SPEAKER_NAMES = ["Adame", "Moussa", "Bourama", "Modibo", "Seydou"]

# (text, speaker) pairs surfaced as one-click example buttons in the UI.
examples = [
    ["Aw ni ce", "Adame"],
    ["I ni ce", "Moussa"],
    ["Aw ni tile", "Bourama"],
    ["I ka kene wa?", "Modibo"],
    ["Ala ka Mali suma", "Adame"],
    ["sigikafɔ kɔnɔ jamanaw ni ɲɔgɔn cɛ, olu ye a haminankow ye, wa o ko ninnu ka kan ka kɛ sariya ani tilennenya kɔnɔ", "Seydou"],
    ["Aw ni ce. Ne tɔgɔ ye Kaya Magan. Aw Sanbe Sanbe.", "Moussa"],
    ["An dɔlakelen bɛ masike bilenman don ka tɔw gɛn.", "Bourama"],
    ["Aw ni ce. Seidu bɛ aw fo wa aw ka yafa a ma, ka da a kan tuma dɔw la kow ka can.", "Modibo"],
]
def build_interface():
    """Build and return the Gradio Blocks UI for Bambara TTS.

    Lays out the text/speaker inputs, the optional advanced sampling
    controls, the audio output, example buttons, and wires both the
    generate button and textbox submit to generate_speech.
    """
    with gr.Blocks(title="Bambara TTS - EXPERIMENTAL", theme=gr.themes.Soft()) as demo:
        # Page header and usage note.
        gr.Markdown("""
        # 🎤 Bambara Text-to-Speech ⚠️ EXPERIMENTAL
        Convert Bambara text to speech using AI. This model is currently experimental.
        **Bambara** is spoken by millions of people in Mali and West Africa.
        ⚡ **Note**: Model loads automatically on first use and stays loaded for optimal performance.
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                text_input = gr.Textbox(
                    label="📝 Bambara Text",
                    placeholder="Type your Bambara text here...",
                    lines=3,
                    max_lines=6,
                    value="Aw ni ce"
                )
                speaker_dropdown = gr.Dropdown(
                    choices=SPEAKER_NAMES,
                    value="Adame",
                    label="🗣️ Speaker Voice"
                )
                generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")

            with gr.Column(scale=1):
                # Advanced sampling controls, hidden until the checkbox is ticked.
                use_advanced = gr.Checkbox(
                    label="⚙️ Use Advanced Settings",
                    value=False,
                    info="Enable to customize generation parameters"
                )
                with gr.Group(visible=False) as advanced_group:
                    gr.Markdown("**Advanced Parameters:**")
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature",
                        info="Higher = more varied"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Top-K"
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-P"
                    )
                    max_tokens = gr.Slider(
                        minimum=256,
                        maximum=4096,
                        value=2048,
                        step=256,
                        label="Max Length"
                    )

        # Output section: audio player plus a status line for messages.
        gr.Markdown("### 🔊 Generated Audio")
        audio_output = gr.Audio(
            label="Generated Speech",
            type="numpy",
            interactive=False
        )
        status_output = gr.Textbox(
            label="Status",
            interactive=False,
            show_label=False,
            container=False
        )

        with gr.Accordion("📚 Try These Examples", open=True):
            # Loading an example also resets advanced params to their defaults.
            def load_example(text, speaker):
                return text, speaker, False, 0.8, 50, 0.9, 2048

            gr.Markdown("**Click any example below:**")
            for i, (text, speaker) in enumerate(examples):
                btn = gr.Button(f"🎯 {text[:30]}{'...' if len(text) > 30 else ''}", size="sm")
                # Bind text/speaker as lambda defaults so each button captures
                # its own example (avoids Python's late-binding closure pitfall).
                btn.click(
                    fn=lambda t=text, s=speaker: load_example(t, s),
                    outputs=[text_input, speaker_dropdown, use_advanced, temperature, top_k, top_p, max_tokens]
                )

        # Information section
        with gr.Accordion("ℹ️ About", open=False):
            gr.Markdown("""
            **⚠️ This is an experimental Bambara TTS model.**
            - **Model**: Based on SparkTTS architecture with BiCodec
            - **Languages**: Bambara (bm)
            - **Speakers**: 5 different voice options
            - **Sample Rate**: 16kHz
            - **Architecture**: Neural codec with semantic and global tokens
            ## 🚀 How to Use
            1. **Enter Text**: Type your Bambara text in the input box
            2. **Choose Speaker**: Select from 5 available voice options
            3. **Advanced Settings**: Optionally adjust generation parameters
            4. **Generate**: Click the generate button to create speech
            """)

        # Show/hide the advanced settings group when the checkbox changes.
        def toggle_advanced(use_adv):
            return gr.Group(visible=use_adv)

        use_advanced.change(
            fn=toggle_advanced,
            inputs=[use_advanced],
            outputs=[advanced_group]
        )

        # Trigger generation from the button or from Enter in the textbox.
        generate_btn.click(
            fn=generate_speech,
            inputs=[text_input, speaker_dropdown, use_advanced, temperature, top_k, top_p, max_tokens],
            outputs=[audio_output, status_output],
            show_progress=True
        )
        text_input.submit(
            fn=generate_speech,
            inputs=[text_input, speaker_dropdown, use_advanced, temperature, top_k, top_p, max_tokens],
            outputs=[audio_output, status_output],
            show_progress=True
        )

    return demo
def main():
    """Build the Gradio interface and launch the web server.

    Note: `interface.launch()` blocks for the lifetime of the server, so
    readiness is logged *before* the call — the original logged "launched
    successfully" only after launch() returned, i.e. after shutdown.
    """
    logger.info("Starting Bambara TTS Gradio interface.")
    interface = build_interface()
    logger.info("Gradio interface launched successfully.")
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )


if __name__ == "__main__":
    main()