sudoping01 commited on
Commit
e849c49
·
verified ·
1 Parent(s): d312b52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -32
app.py CHANGED
@@ -12,11 +12,11 @@ import spaces
12
  import logging
13
  from huggingface_hub import login
14
  import threading
 
15
 
16
  torch._dynamo.config.disable = True
17
  torch._dynamo.config.suppress_errors = True
18
 
19
-
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
@@ -25,44 +25,68 @@ if hf_token:
25
  login(token=hf_token)
26
 
27
 
28
- tts_model = None
29
- speakers_dict = None
30
- model_initialized = False
31
- model_initialized_lock = threading.Lock()
32
-
33
- @spaces.GPU()
34
- def initialize_model():
35
- """Initialize the TTS model and speakers - called once with GPU context"""
36
- global tts_model, speakers_dict, model_initialized, model_initialized_lock
37
 
38
- # Always acquire lock first for async safety
39
- with model_initialized_lock:
40
- if not model_initialized:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  logger.info("Initializing Bambara TTS model...")
 
42
 
43
  try:
44
  from maliba_ai.tts.inference import BambaraTTSInference
45
  from maliba_ai.config.speakers import Adame, Moussa, Bourama, Modibo, Seydou
46
 
47
- # All initialization inside the lock
48
- tts_model = BambaraTTSInference()
49
- speakers_dict = {
50
- "Adame": Adame,
51
  "Moussa": Moussa,
52
  "Bourama": Bourama,
53
  "Modibo": Modibo,
54
  "Seydou": Seydou
55
  }
56
 
57
- # Set flag last, after everything is ready
58
- model_initialized = True
59
- logger.info("Model initialized successfully!")
60
 
61
  except Exception as e:
62
  logger.error(f"Failed to initialize model: {e}")
63
  raise e
 
 
64
 
65
- return tts_model, speakers_dict
 
 
 
 
 
 
 
66
 
67
  def validate_inputs(text, temperature, top_k, top_p, max_tokens):
68
  if not text or not text.strip():
@@ -81,14 +105,12 @@ def validate_inputs(text, temperature, top_k, top_p, max_tokens):
81
 
82
  @spaces.GPU()
83
  def generate_speech(text, speaker_name, use_advanced, temperature, top_k, top_p, max_tokens):
84
-
85
-
86
  if not text.strip():
87
  return None, "Please enter some Bambara text."
88
 
89
  try:
90
-
91
- tts, speakers = initialize_model()
92
 
93
  speaker = speakers[speaker_name]
94
 
@@ -121,6 +143,15 @@ def generate_speech(text, speaker_name, use_advanced, temperature, top_k, top_p,
121
  logger.error(f"Speech generation failed: {e}")
122
  return None, f"❌ Error: {str(e)}"
123
 
 
 
 
 
 
 
 
 
 
124
 
125
  SPEAKER_NAMES = ["Adame", "Moussa", "Bourama", "Modibo", "Seydou"]
126
 
@@ -129,12 +160,11 @@ examples = [
129
  ["Mali bɛna diya kɔsɛbɛ, ka a da a kan baara bɛ ka kɛ.", "Moussa"],
130
  ["Ne bɛ se ka sɛbɛnni yɛlɛma ka kɛ kuma ye", "Bourama"],
131
  ["I ka kɛnɛ wa?", "Modibo"],
132
- ["Lakɔli karamɔgɔw tun tɛ ka se ka sɛbɛnni kɛ ka ɲɛ walanba kan wa denmisɛnw tun tɛ ka se ka o sɛbɛnni ninnu ye, kuma tɛ ka u kalan. Denmisɛnw kɛra kunfinw ye.", "Adame"],
133
  ["sigikafɔ kɔnɔ jamanaw ni ɲɔgɔn cɛ, olu ye a haminankow ye, wa o ko ninnu ka kan ka kɛ sariya ani tilennenya kɔnɔ.", "Seydou"],
134
  ["Aw ni ce. Ne tɔgɔ ye Adama. Awɔ, ne ye maliden de ye. Aw Sanbɛ Sanbɛ. San min tɛ ɲinan ye, an bɛɛ ka jɛ ka o seli ɲɔgɔn fɛ, hɛɛrɛ ni lafiya la. Ala ka Mali suma. Ala ka Mali yiriwa. Ala ka Mali taa ɲɛ. Ala ka an ka seliw caya. Ala ka yafa an bɛɛ ma.", "Moussa"],
135
  ["An dɔlakelen bɛ masike bilenman don ka tɔw gɛn.", "Bourama"],
136
  ["Aw ni ce. Seidu bɛ aw fo wa aw ka yafa a ma, ka da a kan tuma dɔw la kow ka can.", "Modibo"],
137
-
138
  ]
139
 
140
  def build_interface():
@@ -149,11 +179,8 @@ def build_interface():
149
  Convert Bambara text to speech. This model is currently experimental.
150
 
151
  **Bambara** is spoken by millions of people in Mali and West Africa.
152
- .
153
  """)
154
 
155
-
156
-
157
  with gr.Row():
158
  with gr.Column(scale=2):
159
  text_input = gr.Textbox(
@@ -237,18 +264,20 @@ def build_interface():
237
  gr.Markdown("**Click any example below:**")
238
 
239
  for i, (text, speaker) in enumerate(examples):
240
- btn = gr.Button(f" {text[:30]}{'...' if len(text) > 30 else ''}", size="sm")
241
  btn.click(
242
  fn=lambda t=text, s=speaker: load_example(t, s),
243
  outputs=[text_input, speaker_dropdown, use_advanced, temperature, top_k, top_p, max_tokens]
244
  )
245
 
246
- with gr.Accordion(" About", open=False):
247
  gr.Markdown("""
248
  **⚠️ This is an experimental Bambara TTS model.**
249
  - **Languages**: Bambara (bm)
250
  - **Speakers**: 5 different voice options
251
  - **Sample Rate**: 16kHz
 
 
252
  """)
253
 
254
  def toggle_advanced(use_adv):
@@ -280,6 +309,8 @@ def main():
280
  """Main function to launch the Gradio interface"""
281
  logger.info("Starting Bambara TTS Gradio interface.")
282
 
 
 
283
  interface = build_interface()
284
  interface.launch(
285
  server_name="0.0.0.0",
 
12
  import logging
13
  from huggingface_hub import login
14
  import threading
15
+ import time
16
 
17
  torch._dynamo.config.disable = True
18
  torch._dynamo.config.suppress_errors = True
19
 
 
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
 
25
  login(token=hf_token)
26
 
27
 
28
+ class ModelSingleton:
29
+ _instance = None
30
+ _lock = threading.Lock()
 
 
 
 
 
 
31
 
32
+ def __new__(cls):
33
+ if cls._instance is None:
34
+ with cls._lock:
35
+ if cls._instance is None:
36
+ cls._instance = super(ModelSingleton, cls).__new__(cls)
37
+ cls._instance.initialized = False
38
+ cls._instance.tts_model = None
39
+ cls._instance.speakers_dict = None
40
+ cls._instance.init_lock = threading.RLock()
41
+ return cls._instance
42
+
43
+ @spaces.GPU()
44
+ def initialize(self):
45
+ """Thread-safe initialization with singleton pattern"""
46
+ if self.initialized:
47
+ logger.info("Model already initialized, skipping...")
48
+ return self.tts_model, self.speakers_dict
49
+
50
+ with self.init_lock:
51
+ # Double-check pattern
52
+ if self.initialized:
53
+ logger.info("Model already initialized (double-check), skipping...")
54
+ return self.tts_model, self.speakers_dict
55
+
56
  logger.info("Initializing Bambara TTS model...")
57
+ start_time = time.time()
58
 
59
  try:
60
  from maliba_ai.tts.inference import BambaraTTSInference
61
  from maliba_ai.config.speakers import Adame, Moussa, Bourama, Modibo, Seydou
62
 
63
+ self.tts_model = BambaraTTSInference()
64
+ self.speakers_dict = {
65
+ "Adama": Adame,
 
66
  "Moussa": Moussa,
67
  "Bourama": Bourama,
68
  "Modibo": Modibo,
69
  "Seydou": Seydou
70
  }
71
 
72
+ self.initialized = True
73
+ elapsed = time.time() - start_time
74
+ logger.info(f"Model initialized successfully in {elapsed:.2f} seconds!")
75
 
76
  except Exception as e:
77
  logger.error(f"Failed to initialize model: {e}")
78
  raise e
79
+
80
+ return self.tts_model, self.speakers_dict
81
 
82
+ def get_model(self):
83
+ """Get the model, initializing if needed"""
84
+ if not self.initialized:
85
+ return self.initialize()
86
+ return self.tts_model, self.speakers_dict
87
+
88
+ # Global singleton instance
89
+ model_singleton = ModelSingleton()
90
 
91
  def validate_inputs(text, temperature, top_k, top_p, max_tokens):
92
  if not text or not text.strip():
 
105
 
106
  @spaces.GPU()
107
  def generate_speech(text, speaker_name, use_advanced, temperature, top_k, top_p, max_tokens):
 
 
108
  if not text.strip():
109
  return None, "Please enter some Bambara text."
110
 
111
  try:
112
+ # Get model through singleton
113
+ tts, speakers = model_singleton.get_model()
114
 
115
  speaker = speakers[speaker_name]
116
 
 
143
  logger.error(f"Speech generation failed: {e}")
144
  return None, f"❌ Error: {str(e)}"
145
 
146
+ # Preload model on startup (optional - comment out if you prefer lazy loading)
147
+ def preload_model():
148
+ """Preload the model when the app starts"""
149
+ try:
150
+ logger.info("Preloading model...")
151
+ model_singleton.initialize()
152
+ logger.info("Model preloaded successfully!")
153
+ except Exception as e:
154
+ logger.error(f"Failed to preload model: {e}")
155
 
156
  SPEAKER_NAMES = ["Adame", "Moussa", "Bourama", "Modibo", "Seydou"]
157
 
 
160
  ["Mali bɛna diya kɔsɛbɛ, ka a da a kan baara bɛ ka kɛ.", "Moussa"],
161
  ["Ne bɛ se ka sɛbɛnni yɛlɛma ka kɛ kuma ye", "Bourama"],
162
  ["I ka kɛnɛ wa?", "Modibo"],
163
+ ["Lakɔli karamɔgɔw tun tɛ ka se ka sɛbɛnni kɛ ka ɲɛ walanba kan wa denmisɛnw tun tɛ ka se ka o sɛbɛnni ninnu ye, kuma tɛ ka u kalan. Denmisɛnw kɛra kunfinw ye.", "Adama"],
164
  ["sigikafɔ kɔnɔ jamanaw ni ɲɔgɔn cɛ, olu ye a haminankow ye, wa o ko ninnu ka kan ka kɛ sariya ani tilennenya kɔnɔ.", "Seydou"],
165
  ["Aw ni ce. Ne tɔgɔ ye Adama. Awɔ, ne ye maliden de ye. Aw Sanbɛ Sanbɛ. San min tɛ ɲinan ye, an bɛɛ ka jɛ ka o seli ɲɔgɔn fɛ, hɛɛrɛ ni lafiya la. Ala ka Mali suma. Ala ka Mali yiriwa. Ala ka Mali taa ɲɛ. Ala ka an ka seliw caya. Ala ka yafa an bɛɛ ma.", "Moussa"],
166
  ["An dɔlakelen bɛ masike bilenman don ka tɔw gɛn.", "Bourama"],
167
  ["Aw ni ce. Seidu bɛ aw fo wa aw ka yafa a ma, ka da a kan tuma dɔw la kow ka can.", "Modibo"],
 
168
  ]
169
 
170
  def build_interface():
 
179
  Convert Bambara text to speech. This model is currently experimental.
180
 
181
  **Bambara** is spoken by millions of people in Mali and West Africa.
 
182
  """)
183
 
 
 
184
  with gr.Row():
185
  with gr.Column(scale=2):
186
  text_input = gr.Textbox(
 
264
  gr.Markdown("**Click any example below:**")
265
 
266
  for i, (text, speaker) in enumerate(examples):
267
+ btn = gr.Button(f"{text[:30]}{'...' if len(text) > 30 else ''}", size="sm")
268
  btn.click(
269
  fn=lambda t=text, s=speaker: load_example(t, s),
270
  outputs=[text_input, speaker_dropdown, use_advanced, temperature, top_k, top_p, max_tokens]
271
  )
272
 
273
+ with gr.Accordion("About", open=False):
274
  gr.Markdown("""
275
  **⚠️ This is an experimental Bambara TTS model.**
276
  - **Languages**: Bambara (bm)
277
  - **Speakers**: 5 different voice options
278
  - **Sample Rate**: 16kHz
279
+
280
+ **Status**: Model loads once and reuses for all requests
281
  """)
282
 
283
  def toggle_advanced(use_adv):
 
309
  """Main function to launch the Gradio interface"""
310
  logger.info("Starting Bambara TTS Gradio interface.")
311
 
312
+ preload_model()
313
+
314
  interface = build_interface()
315
  interface.launch(
316
  server_name="0.0.0.0",