mutisya commited on
Commit
9a09ff1
·
verified ·
1 Parent(s): 7178214

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -26
app.py CHANGED
@@ -35,30 +35,6 @@ def get_translation_pipeline(translation_model_path):
35
 
36
  translator = get_translation_pipeline("mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4")
37
 
38
-
39
- def load_tts_model(model_id):
40
- model_pipeline = pipeline("text-to-speech", model=model_id, device=device)
41
- return model_pipeline
42
-
43
- def initialize_tts_pipelines(load_models=False):
44
- global tts_config_settings
45
- global tts_pipelines
46
- with open(f"tts_models_config.json") as f:
47
- tts_config_settings = json.loads(f.read())
48
-
49
- for lang, lang_config in tts_config_settings.items():
50
- if lang in tts_preload_languages or load_models:
51
- tts_pipelines[lang] = load_tts_model(lang_config["model_repo"])
52
-
53
- def ensure_tts_pipeline_loaded(lang_code):
54
- global tts_config_settings
55
- global tts_pipelines
56
- if lang_code in tts_pipelines:
57
- pipeline = tts_pipelines[lang_code]
58
- else:
59
- lang_config = tts_config_settings[lang_code]
60
- tts_pipelines[lang_code] = load_tts_model(lang_config["model_repo"])
61
-
62
  def load_asr_model(model_id):
63
  model_pipeline = pipeline("automatic-speech-recognition", model=model_id, device=device)
64
  return model_pipeline
@@ -175,6 +151,29 @@ tts_config_settings = {}
175
  tts_pipelines={}
176
  tts_preload_languages=["kik"]
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  @app.post("/text-to-speech", response_model=TTSResponse)
179
  async def text_to_speech(request: TTSRequest):
180
  """
@@ -192,9 +191,11 @@ async def text_to_speech(request: TTSRequest):
192
 
193
  ensure_tts_pipeline_loaded(language)
194
  tts_pipeline = tts_pipelines[language]
 
195
 
196
- audio = tts_pipeline(text, return_tensors=True)["waveform"]
197
- sample_rate = 22050 # Default sample rate for the espnet model
 
198
 
199
  # Save the audio to a BytesIO buffer as a WAV file
200
  buffer = io.BytesIO()
@@ -209,5 +210,8 @@ async def text_to_speech(request: TTSRequest):
209
  raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
210
 
211
  # Run the FastAPI application
 
 
 
212
  if __name__ == "__main__":
213
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
35
 
36
  translator = get_translation_pipeline("mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4")
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def load_asr_model(model_id):
39
  model_pipeline = pipeline("automatic-speech-recognition", model=model_id, device=device)
40
  return model_pipeline
 
151
  tts_pipelines={}
152
  tts_preload_languages=["kik"]
153
 
154
+ def load_tts_model(model_id):
155
+ model_pipeline = pipeline("text-to-speech", model=model_id, device=device)
156
+ return model_pipeline
157
+
158
+ def initialize_tts_pipelines(load_models=False):
159
+ global tts_config_settings
160
+ global tts_pipelines
161
+ with open(f"tts_models_config.json") as f:
162
+ tts_config_settings = json.loads(f.read())
163
+
164
+ for lang, lang_config in tts_config_settings.items():
165
+ if lang in tts_preload_languages or load_models:
166
+ tts_pipelines[lang] = load_tts_model(lang_config["model_repo"])
167
+
168
+ def ensure_tts_pipeline_loaded(lang_code):
169
+ global tts_config_settings
170
+ global tts_pipelines
171
+ if lang_code in tts_pipelines:
172
+ pipeline = tts_pipelines[lang_code]
173
+ else:
174
+ lang_config = tts_config_settings[lang_code]
175
+ tts_pipelines[lang_code] = load_tts_model(lang_config["model_repo"])
176
+
177
  @app.post("/text-to-speech", response_model=TTSResponse)
178
  async def text_to_speech(request: TTSRequest):
179
  """
 
191
 
192
  ensure_tts_pipeline_loaded(language)
193
  tts_pipeline = tts_pipelines[language]
194
+ print("Received request: "+ text)
195
 
196
+ #audio = tts_pipeline(text, return_tensors=True)["waveform"]
197
+ audio = tts_pipeline(text)
198
+ sample_rate = 16000 # Default sample rate for the espnet model
199
 
200
  # Save the audio to a BytesIO buffer as a WAV file
201
  buffer = io.BytesIO()
 
210
  raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
211
 
212
  # Run the FastAPI application
213
+ initialize_tts_pipelines(True)
214
+ initialize_asr_pipelines()
215
+
216
  if __name__ == "__main__":
217
  uvicorn.run(app, host="0.0.0.0", port=7860)