# Spaces: Running on Zero
import gradio as gr | |
import numpy as np | |
import os | |
import spaces | |
import logging | |
from huggingface_hub import login | |
# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Log in to the Hugging Face Hub only when the HF_TOKEN environment
# variable is set to a non-empty value.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Lazily populated model state; initialize_model() fills these in on
# first use.
tts_model = None
speakers_dict = None
model_initialized = False
def initialize_model():
    """Load the Bambara TTS model and speaker table on first use.

    The loaded objects are cached in the module-level ``tts_model`` and
    ``speakers_dict`` globals, and ``model_initialized`` is set once loading
    succeeds, so later calls return the cached objects without reloading.

    Returns:
        tuple: ``(tts_model, speakers_dict)``. ``speakers_dict`` maps the
        names "Adame", "Moussa", "Bourama", "Modibo" and "Seydou" to the
        speaker objects imported from ``maliba_ai.config.speakers``.

    Raises:
        Exception: Anything raised by the ``maliba_ai`` imports or by the
            ``BambaraTTSInference()`` constructor. The error is logged with
            its traceback and re-raised unchanged.
    """
    global tts_model, speakers_dict, model_initialized

    # Guard clause: nothing to do once a previous call has succeeded.
    if model_initialized:
        return tts_model, speakers_dict

    logger.info("Initializing Bambara TTS model...")
    try:
        # Deferred import: the original author imports maliba_ai lazily to
        # avoid CUDA initialization errors at module import time.
        # NOTE(review): no GPU decorator (e.g. from the imported `spaces`
        # package) is visible on this function or its caller - confirm the
        # intended "GPU context" is actually established.
        from maliba_ai.tts.inference import BambaraTTSInference
        from maliba_ai.config.speakers import Adame, Moussa, Bourama, Modibo, Seydou

        model = BambaraTTSInference()
        speakers = {
            "Adame": Adame,
            "Moussa": Moussa,
            "Bourama": Bourama,
            "Modibo": Modibo,
            "Seydou": Seydou,
        }
    except Exception as e:
        # logger.exception keeps the original message and adds the
        # traceback; a bare ``raise`` re-raises without adding a frame.
        logger.exception("Failed to initialize model: %s", e)
        raise

    # Publish the globals only after everything was built successfully.
    tts_model = model
    speakers_dict = speakers
    model_initialized = True
    logger.info("Model initialized successfully!")
    return tts_model, speakers_dict
def validate_inputs(text, temperature, top_k, top_p, max_tokens):
    """Check the text and the advanced generation parameters.

    Args:
        text: Text to synthesize; must contain a non-whitespace character.
        temperature: Sampling temperature, accepted in [0.001, 2.0].
        top_k: Top-K value, accepted in [1, 100].
        top_p: Top-P value, accepted in [0.1, 1.0].
        max_tokens: Maximum generation length; must be at least 1.

    Returns:
        tuple: ``(True, "")`` when every check passes, otherwise
        ``(False, message)`` where ``message`` describes the first failed
        check.
    """
    if not text or not text.strip():
        return False, "Please enter some Bambara text."

    # (value, lower bound, upper bound, message) for each inclusive range.
    range_checks = (
        (temperature, 0.001, 2.0, "Temperature must be between 0.001 and 2.0"),
        (top_k, 1, 100, "Top-K must be between 1 and 100"),
        (top_p, 0.1, 1.0, "Top-P must be between 0.1 and 1.0"),
    )
    for value, lower, upper, message in range_checks:
        # Written as ``not (a <= v <= b)`` so that NaN is rejected too.
        if not (lower <= value <= upper):
            return False, message

    # max_tokens was previously accepted without any check; a value below 1
    # (or a missing value) cannot describe a usable generation length.
    if max_tokens is None or not (max_tokens >= 1):
        return False, "Max Length must be at least 1"

    return True, ""
def generate_speech(text, speaker_name, use_advanced, temperature, top_k, top_p, max_tokens):
    """Synthesize speech for ``text`` with the selected speaker.

    Args:
        text: Bambara text to synthesize. ``None``, empty or
            whitespace-only text is rejected without loading the model.
        speaker_name: Key into the speaker table returned by
            ``initialize_model()``.
        use_advanced: When true, the four sampling parameters below are
            validated and forwarded to the model; otherwise the model is
            called with text and speaker only.
        temperature: Sampling temperature (advanced mode only).
        top_k: Top-K value, forwarded as ``int`` (advanced mode only).
        top_p: Top-P value (advanced mode only).
        max_tokens: Forwarded as ``max_new_audio_tokens`` after ``int``
            conversion (advanced mode only).

    Returns:
        tuple: ``((sample_rate, waveform), status_message)`` on success, or
        ``(None, message)`` when the input is rejected or generation fails.
        Exceptions raised during validation, model loading or generation
        are logged and reported through ``message`` instead of propagating.
    """
    # Reject missing or blank text up front; ``not text`` also covers None,
    # which would otherwise fail on ``.strip()``.
    if not text or not text.strip():
        return None, "Please enter some Bambara text."
    clean_text = text.strip()

    try:
        # Validate before touching the model so that bad parameters never
        # trigger the model load.
        if use_advanced:
            is_valid, error_msg = validate_inputs(clean_text, temperature, top_k, top_p, max_tokens)
            if not is_valid:
                return None, f"❌ {error_msg}"

        tts, speakers = initialize_model()

        # Report an unknown speaker explicitly rather than as a KeyError repr.
        if speaker_name not in speakers:
            return None, f"❌ Unknown speaker: {speaker_name}"

        generation_kwargs = {"text": clean_text, "speaker_id": speakers[speaker_name]}
        if use_advanced:
            generation_kwargs.update(
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                max_new_audio_tokens=int(max_tokens),
            )
        waveform = tts.generate_speech(**generation_kwargs)

        # Treat a missing result the same way as an empty one.
        if waveform is None or waveform.size == 0:
            return None, "Failed to generate audio. Please try again."

        # 16 kHz, matching the sample rate stated in the app's About text.
        sample_rate = 16000
        return (sample_rate, waveform), "✅ Audio generated successfully"
    except Exception as e:
        # UI boundary: log with traceback and surface the error as status
        # text so the interface keeps working.
        logger.exception("Speech generation failed: %s", e)
        return None, f"❌ Error: {str(e)}"
# Speaker names offered in the UI dropdown.
SPEAKER_NAMES = [
    "Adame",
    "Moussa",
    "Bourama",
    "Modibo",
    "Seydou",
]

# Example prompts shown as buttons: each entry is [text, speaker name].
examples = [
    ["Aw ni ce", "Adame"],
    ["I ni ce", "Moussa"],
    ["Aw ni tile", "Bourama"],
    ["I ka kene wa?", "Modibo"],
    ["Ala ka Mali suma", "Adame"],
    [
        # One long sentence, split across lines by implicit concatenation.
        "sigikafɔ kɔnɔ jamanaw ni ɲɔgɔn cɛ, olu ye a haminankow ye, "
        "wa o ko ninnu ka kan ka kɛ sariya ani tilennenya kɔnɔ",
        "Seydou",
    ],
    ["Aw ni ce. Ne tɔgɔ ye Kaya Magan. Aw Sanbe Sanbe.", "Moussa"],
    ["An dɔlakelen bɛ masike bilenman don ka tɔw gɛn.", "Bourama"],
    [
        "Aw ni ce. Seidu bɛ aw fo wa aw ka yafa a ma, "
        "ka da a kan tuma dɔw la kow ka can.",
        "Modibo",
    ],
]
def build_interface():
    """Assemble the Gradio Blocks UI for the Bambara TTS demo.

    The page contains a text box, a speaker dropdown and a generate button;
    a checkbox that reveals four sampling sliders; an audio player with a
    status line; one button per entry in ``examples``; and an "About"
    accordion. Pressing the generate button or submitting the text box
    both call ``generate_speech``.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).
    """
    # Slider defaults, also used when an example button resets the form.
    default_temperature = 0.8
    default_top_k = 50
    default_top_p = 0.9
    default_max_tokens = 2048

    def _example_label(sample_text):
        # Button caption: the first 30 characters, with "..." if truncated.
        suffix = "..." if len(sample_text) > 30 else ""
        return f"🎯 {sample_text[:30]}{suffix}"

    def _make_example_loader(sample_text, sample_speaker):
        # Bind one example's values so that each button loads its own text
        # and speaker and resets the advanced controls to their defaults.
        def _load():
            return (
                sample_text,
                sample_speaker,
                False,
                default_temperature,
                default_top_k,
                default_top_p,
                default_max_tokens,
            )

        return _load

    def toggle_advanced(use_adv):
        # Show or hide the advanced group to follow the checkbox state.
        return gr.Group(visible=use_adv)

    with gr.Blocks(title="Bambara TTS - EXPERIMENTAL", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎤 Bambara Text-to-Speech ⚠️ EXPERIMENTAL
        Convert Bambara text to speech using AI. This model is currently experimental.
        **Bambara** is spoken by millions of people in Mali and West Africa.
        ⚡ **Note**: Model loads automatically on first use and stays loaded for optimal performance.
        """)

        with gr.Row():
            # Left column: text, speaker and the generate button.
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="📝 Bambara Text",
                    placeholder="Type your Bambara text here...",
                    lines=3,
                    max_lines=6,
                    value="Aw ni ce",
                )
                speaker_dropdown = gr.Dropdown(
                    choices=SPEAKER_NAMES,
                    value="Adame",
                    label="🗣️ Speaker Voice",
                )
                generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")

            # Right column: optional sampling controls, hidden by default.
            with gr.Column(scale=1):
                use_advanced = gr.Checkbox(
                    label="⚙️ Use Advanced Settings",
                    value=False,
                    info="Enable to customize generation parameters",
                )
                with gr.Group(visible=False) as advanced_group:
                    gr.Markdown("**Advanced Parameters:**")
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=default_temperature,
                        step=0.1,
                        label="Temperature",
                        info="Higher = more varied",
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=default_top_k,
                        step=5,
                        label="Top-K",
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=default_top_p,
                        step=0.05,
                        label="Top-P",
                    )
                    max_tokens = gr.Slider(
                        minimum=256,
                        maximum=4096,
                        value=default_max_tokens,
                        step=256,
                        label="Max Length",
                    )

        gr.Markdown("### 🔊 Generated Audio")
        audio_output = gr.Audio(
            label="Generated Speech",
            type="numpy",
            interactive=False,
        )
        status_output = gr.Textbox(
            label="Status",
            interactive=False,
            show_label=False,
            container=False,
        )

        # Components written by the example buttons, in the order returned
        # by the loader callbacks.
        form_fields = [text_input, speaker_dropdown, use_advanced, temperature, top_k, top_p, max_tokens]

        with gr.Accordion("📚 Try These Examples", open=True):
            gr.Markdown("**Click any example below:**")
            for sample_text, sample_speaker in examples:
                example_btn = gr.Button(_example_label(sample_text), size="sm")
                example_btn.click(
                    fn=_make_example_loader(sample_text, sample_speaker),
                    outputs=list(form_fields),
                )

        # Information section
        with gr.Accordion("ℹ️ About", open=False):
            gr.Markdown("""
            **⚠️ This is an experimental Bambara TTS model.**
            - **Model**: Based on SparkTTS architecture with BiCodec
            - **Languages**: Bambara (bm)
            - **Speakers**: 5 different voice options
            - **Sample Rate**: 16kHz
            - **Architecture**: Neural codec with semantic and global tokens
            ## 🚀 How to Use
            1. **Enter Text**: Type your Bambara text in the input box
            2. **Choose Speaker**: Select from 5 available voice options
            3. **Advanced Settings**: Optionally adjust generation parameters
            4. **Generate**: Click the generate button to create speech
            """)

        use_advanced.change(
            fn=toggle_advanced,
            inputs=[use_advanced],
            outputs=[advanced_group],
        )

        # The button click and Enter in the text box run the same handler
        # with the same inputs and outputs.
        for register_event in (generate_btn.click, text_input.submit):
            register_event(
                fn=generate_speech,
                inputs=list(form_fields),
                outputs=[audio_output, status_output],
                show_progress=True,
            )

    return demo
def main():
    """Build the Bambara TTS interface and serve it on 0.0.0.0:7860."""
    logger.info("Starting Bambara TTS Gradio interface.")

    # Listen on all network interfaces at port 7860, without a public
    # share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    app = build_interface()
    app.launch(**launch_options)

    # NOTE(review): if launch() blocks until the server stops, this message
    # is only emitted at shutdown - confirm the intended timing.
    logger.info("Gradio interface launched successfully.")


if __name__ == "__main__":
    main()