sachin commited on
Commit
0a0efec
·
1 Parent(s): 1936ef7

update tts

Browse files
Files changed (1) hide show
  1. src/server/main.py +31 -5
src/server/main.py CHANGED
@@ -69,10 +69,29 @@ class Settings(BaseSettings):
69
 
70
  settings = Settings()
71
 
72
- # TTS Setup
73
- tts_repo_id = "ai4bharat/IndicF5"
74
- tts_model = AutoModel.from_pretrained(tts_repo_id, trust_remote_code=True).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
 
76
  EXAMPLES = [
77
  {
78
  "audio_name": "KAN_F (Happy)",
@@ -99,7 +118,7 @@ def load_audio_from_url(url: str):
99
  return sample_rate, audio_data
100
  raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
101
 
102
- def synthesize_speech(text: str, ref_audio_name: str, ref_text: str):
103
  ref_audio_url = None
104
  for example in EXAMPLES:
105
  if example["audio_name"] == ref_audio_name:
@@ -119,7 +138,7 @@ def synthesize_speech(text: str, ref_audio_name: str, ref_text: str):
119
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
120
  sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
121
  temp_audio.flush()
122
- audio = tts_model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
123
 
124
  if audio.dtype == np.int16:
125
  audio = audio.astype(np.float32) / 32768.0
@@ -233,6 +252,7 @@ class ASRModelManager:
233
  llm_manager = LLMManager(settings.llm_model_name)
234
  model_manager = ModelManager()
235
  asr_manager = ASRModelManager()
 
236
  ip = IndicProcessor(inference=True)
237
 
238
  # Pydantic Models
@@ -278,6 +298,7 @@ async def lifespan(app: FastAPI):
278
  tasks = [
279
  asyncio.create_task(llm_manager.load()),
280
  asyncio.create_task(asr_manager.load()),
 
281
  asyncio.create_task(model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic')),
282
  asyncio.create_task(model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng')),
283
  asyncio.create_task(model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic')),
@@ -314,11 +335,14 @@ app.state.limiter = limiter
314
  # API Endpoints
315
  @app.post("/audio/speech", response_class=StreamingResponse)
316
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
 
 
317
  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
318
  if not request.text.strip():
319
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
320
 
321
  audio_buffer = synthesize_speech(
 
322
  text=request.text,
323
  ref_audio_name="KAN_F (Happy)",
324
  ref_text=kannada_example["ref_text"]
@@ -610,6 +634,8 @@ async def speech_to_speech(
610
  file: UploadFile = File(...),
611
  language: str = Query(..., enum=list(asr_manager.model_language.keys())),
612
  ) -> StreamingResponse:
 
 
613
  transcription = await transcribe_audio(file, language)
614
  logger.info(f"Transcribed text: {transcription.text}")
615
 
 
69
 
70
  settings = Settings()
71
 
72
+ # TTS Manager
73
+ class TTSManager:
74
+ def __init__(self, device_type=device):
75
+ self.device_type = device_type
76
+ self.model = None
77
+ self.repo_id = "ai4bharat/IndicF5"
78
+
79
+ async def load(self):
80
+ logger.info("Loading TTS model IndicF5...")
81
+ self.model = await asyncio.to_thread(
82
+ AutoModel.from_pretrained,
83
+ self.repo_id,
84
+ trust_remote_code=True
85
+ )
86
+ self.model = self.model.to(self.device_type)
87
+ logger.info("TTS model IndicF5 loaded")
88
+
89
+ def synthesize(self, text, ref_audio_path, ref_text):
90
+ if not self.model:
91
+ raise ValueError("TTS model not loaded")
92
+ return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
93
 
94
+ # TTS Constants
95
  EXAMPLES = [
96
  {
97
  "audio_name": "KAN_F (Happy)",
 
118
  return sample_rate, audio_data
119
  raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
120
 
121
+ def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
122
  ref_audio_url = None
123
  for example in EXAMPLES:
124
  if example["audio_name"] == ref_audio_name:
 
138
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
139
  sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
140
  temp_audio.flush()
141
+ audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
142
 
143
  if audio.dtype == np.int16:
144
  audio = audio.astype(np.float32) / 32768.0
 
252
  llm_manager = LLMManager(settings.llm_model_name)
253
  model_manager = ModelManager()
254
  asr_manager = ASRModelManager()
255
+ tts_manager = TTSManager()
256
  ip = IndicProcessor(inference=True)
257
 
258
  # Pydantic Models
 
298
  tasks = [
299
  asyncio.create_task(llm_manager.load()),
300
  asyncio.create_task(asr_manager.load()),
301
+ asyncio.create_task(tts_manager.load()),
302
  asyncio.create_task(model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic')),
303
  asyncio.create_task(model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng')),
304
  asyncio.create_task(model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic')),
 
335
  # API Endpoints
336
  @app.post("/audio/speech", response_class=StreamingResponse)
337
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
338
+ if not tts_manager.model:
339
+ raise HTTPException(status_code=503, detail="TTS model still loading, please try again later")
340
  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
341
  if not request.text.strip():
342
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
343
 
344
  audio_buffer = synthesize_speech(
345
+ tts_manager,
346
  text=request.text,
347
  ref_audio_name="KAN_F (Happy)",
348
  ref_text=kannada_example["ref_text"]
 
634
  file: UploadFile = File(...),
635
  language: str = Query(..., enum=list(asr_manager.model_language.keys())),
636
  ) -> StreamingResponse:
637
+ if not tts_manager.model:
638
+ raise HTTPException(status_code=503, detail="TTS model still loading, please try again later")
639
  transcription = await transcribe_audio(file, language)
640
  logger.info(f"Transcribed text: {transcription.text}")
641