# MalianTTS / app.py
# Source: Hugging Face Space (commit fb7a6bf, "Update app.py" by sudoping01)
import gradio as gr
from transformers import VitsModel, AutoTokenizer
import torch
import logging
import spaces
from typing import Tuple, Optional
import numpy as np
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
if torch.cuda.is_available():
device = "cuda"
logger.info("Using CUDA for inference.")
elif torch.backends.mps.is_available():
device = "mps"
logger.info("Using MPS for inference.")
else:
device = "cpu"
logger.info("Using CPU for inference.")
# Languages supported by the checkpoint; each has its own VITS model and
# tokenizer under the repo's models/<lang> subfolder (see MalianTTS below).
# Order here is also the display order of the Gradio dropdown.
languages = ["bambara", "boomu", "dogon", "pular", "songhoy", "tamasheq"]
# One sample sentence per language, surfaced by the "Load Example" button.
examples = {
    "bambara": "An filɛ ni ye yɔrɔ minna ni an ye an sigi ka a layɛ yala an bɛ ka baara min kɛ ɛsike a kɛlen don ka Ɲɛ wa ?",
    "boomu": "Vunurobe wozomɛ pɛɛ, Poli we zo woro han Deeɓenu wara li Deeɓenu faralo zuun. Lo we baba a lo wara yi see ɓa Zuwifera ma ɓa Gɛrɛkela wa.",
    "dogon": "Pɔɔlɔ, kubɔ lugo joo le, bana dɛin dɛin le, inɛw Ama titiyaanw le digɛu, Ama, emɛ babe bɛrɛ sɔɔ sɔi.",
    "pular": "Miɗo ndaarde saabe Laamɗo e saabe Iisaa Almasiihu caroyoowo wuurɓe e maayɓe oo, miɗo ndaardire saabe gartol makko ka num e Laamu makko",
    "songhoy": "Haya ka se beenediyo kokoyteraydi go hima nda huukoy foo ka fatta ja subaahi ka taasi goykoyyo ngu rezẽ faridi se",
    "tamasheq": "Toḍă tăfukt ɣas, issăɣră-dd măssi-s n-ašĕkrĕš ănaẓraf-net, inn'-as: 'Ǝɣĕr-dd inaxdimăn, tĕẓlĕd-asăn, sănt s-wi dd-ĕšrăynen har tĕkkĕd wi dd-ăzzarnen."
}
class MalianTTS:
    """Text-to-speech for six Malian languages via per-language VITS models.

    One VitsModel/tokenizer pair per entry in the module-level ``languages``
    list is loaded from the ``models/<lang>`` subfolders of a single
    Hugging Face repo and kept resident on ``device``.
    """

    def __init__(self, model_name: str = "MALIBA-AI/malian-tts"):
        self.model_name = model_name
        self.models = {}      # lang -> VitsModel, already moved to `device`
        self.tokenizers = {}  # lang -> matching tokenizer
        self._load_models()

    def _load_models(self) -> None:
        """Load a model and tokenizer for every supported language.

        Raises:
            RuntimeError: if any model or tokenizer fails to load (the
                original exception is chained as ``__cause__``).
        """
        try:
            for lang in languages:
                logger.info("Loading model and tokenizer for %s...", lang)
                subfolder = f"models/{lang}"
                self.models[lang] = VitsModel.from_pretrained(
                    self.model_name,
                    subfolder=subfolder
                ).to(device)
                self.tokenizers[lang] = AutoTokenizer.from_pretrained(
                    self.model_name,
                    subfolder=subfolder
                )
                logger.info("Successfully loaded %s", lang)
        except Exception as e:
            logger.error("Failed to load models: %s", e)
            # RuntimeError is still an Exception subclass (backward
            # compatible for callers); chaining keeps the root cause visible.
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

    def synthesize(self, language: str, text: str) -> Tuple[Optional[Tuple[int, np.ndarray]], Optional[str]]:
        """Generate audio from text for the specified language.

        Args:
            language: one of the keys loaded in ``self.models``.
            text: input text; blank input is rejected.

        Returns:
            ``((sample_rate, waveform), None)`` on success, or
            ``(None, error_message)`` on failure — never raises.
        """
        if not text.strip():
            return None, "Please enter some text to synthesize."
        if language not in self.models:
            # Explicit guard instead of surfacing a raw KeyError message.
            return None, f"Unsupported language: {language}"
        try:
            model = self.models[language]
            tokenizer = self.tokenizers[language]
            inputs = tokenizer(text, return_tensors="pt").to(device)
            with torch.no_grad():
                output = model(**inputs).waveform
            waveform = output.squeeze().cpu().numpy()
            sample_rate = model.config.sampling_rate
            return (sample_rate, waveform), None
        except Exception as e:
            logger.error(f"Error during inference for {language}: {str(e)}")
            return None, f"Error generating audio: {str(e)}"
# Initialize the TTS system once at import time so the Gradio callbacks reuse
# the already-loaded models; loads all six models, so startup is slow but
# per-request synthesis is fast.
tts_system = MalianTTS()
@spaces.GPU()
def generate_audio(language: str, text: str) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
    """
    Generate audio from text using the specified language model.
    """
    # Reject blank input before touching the model.
    if not text.strip():
        return None, "Please enter some text to synthesize."
    try:
        audio_output, error_msg = tts_system.synthesize(language, text)
        # synthesize() reports failures via its second return value.
        if error_msg:
            logger.error(f"TTS generation failed: {error_msg}")
            return None, error_msg
        logger.info(f"Successfully generated audio for {language}")
        return audio_output, "Audio generated successfully!"
    except Exception as e:
        # Last-resort guard so the UI always gets a status string.
        logger.error(f"Audio generation failed: {e}")
        return None, f"Error: {str(e)}"
def load_example(language: str, example_map: Optional[dict] = None) -> str:
    """Return the sample sentence for the selected language.

    Args:
        language: language key, e.g. ``"bambara"``.
        example_map: optional mapping to look up instead of the module-level
            ``examples`` dict; defaults to ``examples``, so existing callers
            (the Gradio button) are unaffected.

    Returns:
        The example text, or ``"No example available"`` for unknown keys.
    """
    source = examples if example_map is None else example_map
    return source.get(language, "No example available")
def build_interface():
    """
    Builds the Gradio interface for Malian TTS.
    """
    with gr.Blocks(title="MalianVoices") as demo:
        # Header and usage instructions.
        gr.Markdown(
            """
# MalianVoices: 🇲🇱 Text-to-Speech in Six Malian Languages
Lightweight TTS for six Malian languages: **Bambara, Boomu, Dogon, Pular, Songhoy, Tamasheq**.
- ✅ Real-time TTS with fast response
## How to Use
1. Pick a language from the dropdown
2. Enter your text or load an example
3. Click **"Generate Audio"** to listen
"""
        )

        # Input controls.
        with gr.Row():
            language = gr.Dropdown(choices=languages, label="Language", value="bambara")
        with gr.Column():
            text = gr.Textbox(label="Input Text", lines=5, placeholder="Type your text here...")
        with gr.Row():
            example_btn = gr.Button("Load Example")
            generate_btn = gr.Button("Generate Audio", variant="primary")

        # Output widgets.
        audio_output = gr.Audio(label="Generated Audio", type="numpy")
        status_msg = gr.Textbox(label="Status", interactive=False)

        # Footer
        gr.Markdown(
            """
By [sudoping01](https://huggingface.co/sudoping01), from [sudoping01/malian-tts](https://huggingface.co/sudoping01/malian-tts).
Fine-tuned on Meta's MMS, CC BY-NC 4.0, non-commercial.
"""
        )

        # Wire buttons to callbacks.
        generate_btn.click(fn=generate_audio, inputs=[language, text], outputs=[audio_output, status_msg])
        example_btn.click(fn=load_example, inputs=language, outputs=text)
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it (launch() blocks until
    # the server is shut down, so the final log line runs at teardown).
    logger.info("Starting the Gradio interface for MalianVoices TTS.")
    build_interface().launch()
    logger.info("Gradio interface running.")