File size: 6,681 Bytes
3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 3f4dd11 fb7a6bf 032c2a6 fb7a6bf 3f4dd11 fb7a6bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import gradio as gr
from transformers import VitsModel, AutoTokenizer
import torch
import logging
import spaces
from typing import Tuple, Optional
import numpy as np
# Module-wide logging configuration, done once at import time.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Select the inference device: CUDA first, then Apple MPS, CPU as fallback.
if torch.cuda.is_available():
    device, _device_msg = "cuda", "Using CUDA for inference."
elif torch.backends.mps.is_available():
    device, _device_msg = "mps", "Using MPS for inference."
else:
    device, _device_msg = "cpu", "Using CPU for inference."
logger.info(_device_msg)
# One canned example sentence per supported language. Insertion order is
# significant: it drives both the `languages` list below and the dropdown
# ordering in the UI (bambara first, so it can be the default value).
examples = {
    "bambara": "An filɛ ni ye yɔrɔ minna ni an ye an sigi ka a layɛ yala an bɛ ka baara min kɛ ɛsike a kɛlen don ka Ɲɛ wa ?",
    "boomu": "Vunurobe wozomɛ pɛɛ, Poli we zo woro han Deeɓenu wara li Deeɓenu faralo zuun. Lo we baba a lo wara yi see ɓa Zuwifera ma ɓa Gɛrɛkela wa.",
    "dogon": "Pɔɔlɔ, kubɔ lugo joo le, bana dɛin dɛin le, inɛw Ama titiyaanw le digɛu, Ama, emɛ babe bɛrɛ sɔɔ sɔi.",
    "pular": "Miɗo ndaarde saabe Laamɗo e saabe Iisaa Almasiihu caroyoowo wuurɓe e maayɓe oo, miɗo ndaardire saabe gartol makko ka num e Laamu makko",
    "songhoy": "Haya ka se beenediyo kokoyteraydi go hima nda huukoy foo ka fatta ja subaahi ka taasi goykoyyo ngu rezẽ faridi se",
    "tamasheq": "Toḍă tăfukt ɣas, issăɣră-dd măssi-s n-ašĕkrĕš ănaẓraf-net, inn'-as: 'Ǝɣĕr-dd inaxdimăn, tĕẓlĕd-asăn, sănt s-wi dd-ĕšrăynen har tĕkkĕd wi dd-ăzzarnen."
}

# Supported languages, derived from the example table (dict preserves order).
languages = list(examples)
class MalianTTS:
    """Text-to-speech wrapper holding one MMS/VITS model per Malian language.

    All models and tokenizers listed in the module-level ``languages`` are
    loaded eagerly at construction time and moved to the module-level
    ``device``.
    """

    def __init__(self, model_name: str = "MALIBA-AI/malian-tts"):
        """Load every language's model and tokenizer from the given HF repo.

        Args:
            model_name: Hugging Face repo id containing per-language
                subfolders (``models/<lang>``).

        Raises:
            RuntimeError: if any model or tokenizer fails to load.
        """
        self.model_name = model_name
        self.models = {}
        self.tokenizers = {}
        self._load_models()

    def _load_models(self) -> None:
        """Load all language models and tokenizers."""
        try:
            for lang in languages:
                # Lazy %-style args keep log formatting off the hot path.
                logger.info("Loading model and tokenizer for %s...", lang)
                # Each language lives in its own subfolder of the repo.
                self.models[lang] = VitsModel.from_pretrained(
                    self.model_name,
                    subfolder=f"models/{lang}"
                ).to(device)
                self.tokenizers[lang] = AutoTokenizer.from_pretrained(
                    self.model_name,
                    subfolder=f"models/{lang}"
                )
                logger.info("Successfully loaded %s", lang)
        except Exception as e:
            # logger.exception preserves the traceback; `from e` keeps the
            # original cause chained instead of flattening it to a string.
            logger.exception("Failed to load models")
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

    def synthesize(self, language: str, text: str) -> Tuple[Optional[Tuple[int, np.ndarray]], Optional[str]]:
        """Generate audio from text for the specified language.

        Args:
            language: one of the keys of ``self.models``.
            text: input text to synthesize.

        Returns:
            ``((sample_rate, waveform), None)`` on success, or
            ``(None, error_message)`` on failure.
        """
        if not text.strip():
            return None, "Please enter some text to synthesize."
        # Fail fast with a clear message instead of letting the dict lookup
        # raise a cryptic KeyError that the broad handler would re-wrap.
        if language not in self.models:
            return None, f"Error generating audio: unsupported language '{language}'"
        try:
            model = self.models[language]
            tokenizer = self.tokenizers[language]
            inputs = tokenizer(text, return_tensors="pt").to(device)
            # Inference only: no gradients needed.
            with torch.no_grad():
                output = model(**inputs).waveform
            waveform = output.squeeze().cpu().numpy()
            sample_rate = model.config.sampling_rate
            return (sample_rate, waveform), None
        except Exception as e:
            logger.exception("Error during inference for %s", language)
            return None, f"Error generating audio: {str(e)}"
# Initialize the TTS system eagerly at import time so the first request does
# not pay the model-loading cost. NOTE: this downloads/loads all six models
# onto `device`, so importing this module is slow and network-dependent.
tts_system = MalianTTS()
@spaces.GPU()
def generate_audio(language: str, text: str) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
    """
    Generate audio from text using the specified language model.

    Returns a ``(sample_rate, waveform)`` pair (``None`` on failure) along
    with a human-readable status message for the UI.
    """
    # Guard clause: nothing to synthesize for blank input.
    if not text.strip():
        return None, "Please enter some text to synthesize."
    try:
        audio, err = tts_system.synthesize(language, text)
    except Exception as e:
        logger.error(f"Audio generation failed: {e}")
        return None, f"Error: {str(e)}"
    if err:
        # synthesize() reports expected failures via its second return value.
        logger.error(f"TTS generation failed: {err}")
        return None, err
    logger.info(f"Successfully generated audio for {language}")
    return audio, "Audio generated successfully!"
def load_example(language: str) -> str:
    """Return the canned example sentence for the selected language."""
    try:
        return examples[language]
    except KeyError:
        return "No example available"
def build_interface():
    """Build and return the Gradio Blocks UI for Malian TTS.

    Wires two actions: "Load Example" fills the textbox with a canned
    sentence for the chosen language, and "Generate Audio" runs TTS and
    shows the resulting waveform plus a status message.
    """
    with gr.Blocks(title="MalianVoices") as demo:
        # Header / usage instructions shown at the top of the page.
        gr.Markdown(
            """
# MalianVoices: 🇲🇱 Text-to-Speech in Six Malian Languages
Lightweight TTS for six Malian languages: **Bambara, Boomu, Dogon, Pular, Songhoy, Tamasheq**.
- ✅ Real-time TTS with fast response
## How to Use
1. Pick a language from the dropdown
2. Enter your text or load an example
3. Click **"Generate Audio"** to listen
"""
        )
        # NOTE(review): indentation was lost in extraction; this nesting
        # (dropdown in the row, textbox + button row in a column) is a
        # reconstruction — confirm against the original layout.
        with gr.Row():
            language = gr.Dropdown(
                choices=languages,
                label="Language",
                value="bambara"
            )
            with gr.Column():
                text = gr.Textbox(
                    label="Input Text",
                    lines=5,
                    placeholder="Type your text here..."
                )
                with gr.Row():
                    example_btn = gr.Button("Load Example")
                    generate_btn = gr.Button("Generate Audio", variant="primary")
        # Outputs: type="numpy" matches generate_audio's (sample_rate, ndarray).
        audio_output = gr.Audio(label="Generated Audio", type="numpy")
        status_msg = gr.Textbox(label="Status", interactive=False)
        # Footer
        gr.Markdown(
            """
By [sudoping01](https://huggingface.co/sudoping01), from [sudoping01/malian-tts](https://huggingface.co/sudoping01/malian-tts).
Fine-tuned on Meta's MMS, CC BY-NC 4.0, non-commercial.
"""
        )
        # Connect buttons to functions
        generate_btn.click(
            fn=generate_audio,
            inputs=[language, text],
            outputs=[audio_output, status_msg]
        )
        example_btn.click(
            fn=load_example,
            inputs=language,
            outputs=text
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and block on Gradio's server loop.
    logger.info("Starting the Gradio interface for MalianVoices TTS.")
    demo = build_interface()
    demo.launch()
    logger.info("Gradio interface running.")