sachin committed
Commit fea8b58 · 1 Parent(s): 4d3fcd9

fix-changes

Files changed (1)
  1. src/server/main.py +139 -71
src/server/main.py CHANGED
@@ -321,7 +321,7 @@ EXAMPLES = [
     {
         "audio_name": "KAN_F (Happy)",
         "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
-        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
+        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ。",
         "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
     },
 ]
@@ -335,6 +335,12 @@ class SynthesizeRequest(BaseModel):
 class KannadaSynthesizeRequest(BaseModel):
     text: str
 
+    @field_validator("text")
+    def text_must_be_valid(cls, v):
+        if len(v) > 500:
+            raise ValueError("Text cannot exceed 500 characters")
+        return v.strip()
+
 # TTS Functions
 def load_audio_from_url(url: str):
     response = requests.get(url)
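The added validator on KannadaSynthesizeRequest trims surrounding whitespace and rejects text longer than 500 characters before the endpoint runs. A minimal standalone sketch of the same Pydantic v2 pattern (the model name SynthesisInput is an illustrative stand-in, not part of this commit):

import requests  # noqa: F401 (only the pydantic part matters here)
from pydantic import BaseModel, ValidationError, field_validator

class SynthesisInput(BaseModel):
    # stand-in for KannadaSynthesizeRequest
    text: str

    @field_validator("text")
    @classmethod
    def text_must_be_valid(cls, v):
        if len(v) > 500:
            raise ValueError("Text cannot exceed 500 characters")
        return v.strip()

print(SynthesisInput(text="  hello  ").text)       # "hello" -- whitespace stripped on the way in
try:
    SynthesisInput(text="x" * 501)
except ValidationError as err:
    print(err.error_count(), "validation error")   # the 500-character limit rejects the input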
@@ -343,7 +349,7 @@ def load_audio_from_url(url: str):
         return sample_rate, audio_data
     raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
 
-def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
+async def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str) -> io.BytesIO:
     ref_audio_url = None
     for example in EXAMPLES:
         if example["audio_name"] == ref_audio_name:
@@ -353,23 +359,26 @@ def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, r
             break
 
     if not ref_audio_url:
-        raise HTTPException(status_code=400, detail="Invalid reference audio name.")
+        raise HTTPException(status_code=400, detail=f"Invalid reference audio name: {ref_audio_name}")
     if not text.strip():
-        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
+        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty")
     if not ref_text or not ref_text.strip():
-        raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
+        raise HTTPException(status_code=400, detail="Reference text cannot be empty")
 
+    logger.info(f"Synthesizing speech for text: {text[:50]}... with ref_audio: {ref_audio_name}")
     sample_rate, audio_data = load_audio_from_url(ref_audio_url)
-    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
+
+    async with await asyncio.to_thread(tempfile.NamedTemporaryFile, suffix=".wav", delete=False) as temp_audio:
         sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
         temp_audio.flush()
-        audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
+        audio = await asyncio.to_thread(tts_manager.synthesize, text, temp_audio.name, ref_text)
 
     if audio.dtype == np.int16:
         audio = audio.astype(np.float32) / 32768.0
     buffer = io.BytesIO()
     sf.write(buffer, audio, 24000, format='WAV')
     buffer.seek(0)
+    logger.info("Speech synthesis completed")
     return buffer
 
 # Supported languages
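synthesize_speech is now a coroutine that keeps the event loop free by pushing the blocking model call into a worker thread with asyncio.to_thread. A minimal sketch of that pattern in isolation (slow_synthesize is a made-up stand-in for tts_manager.synthesize):

import asyncio
import time

def slow_synthesize(text: str) -> bytes:
    # stands in for a blocking model call such as tts_manager.synthesize
    time.sleep(1)
    return text.encode("utf-8")

async def handler(text: str) -> bytes:
    # the blocking work runs in a thread; other requests keep being served meanwhile
    return await asyncio.to_thread(slow_synthesize, text)

print(asyncio.run(handler("hello")))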
@@ -414,15 +423,13 @@ class ModelManager:
         self.device_type = device_type
         self.use_distilled = use_distilled
         self.is_lazy_loading = is_lazy_loading
-        # Preload all translation models
         self.preload_models()
 
     def preload_models(self):
-        # Define the core translation pairs to preload
         translation_pairs = [
-            ('eng_Latn', 'kan_Knda', 'eng_indic'),  # English to Indic
-            ('kan_Knda', 'eng_Latn', 'indic_eng'),  # Indic to English
-            ('kan_Knda', 'hin_Deva', 'indic_indic')  # Indic to Indic
+            ('eng_Latn', 'kan_Knda', 'eng_indic'),
+            ('kan_Knda', 'eng_Latn', 'indic_eng'),
+            ('kan_Knda', 'hin_Deva', 'indic_indic')
         ]
         for src_lang, tgt_lang, key in translation_pairs:
             logger.info(f"Preloading translation model for {src_lang} -> {tgt_lang}...")
@@ -545,24 +552,19 @@ translation_configs = []
 async def lifespan(app: FastAPI):
     def load_all_models():
         try:
-            # Load LLM model
             logger.info("Loading LLM model...")
             llm_manager.load()
             logger.info("LLM model loaded successfully")
 
-            # Load TTS model
             logger.info("Loading TTS model...")
             tts_manager.load()
             logger.info("TTS model loaded successfully")
 
-            # Load ASR model
             logger.info("Loading ASR model...")
             asr_manager.load()
             logger.info("ASR model loaded successfully")
 
-            # Translation models are preloaded in ModelManager constructor
             logger.info("Translation models already preloaded in ModelManager initialization.")
-
             logger.info("All models loaded successfully")
         except Exception as e:
             logger.error(f"Error loading models: {str(e)}")
@@ -776,16 +778,18 @@ async def chat_v2(
 # Include LLM Router
 app.include_router(llm_router)
 
-# Other API Endpoints
+# Improved Endpoints
 @app.post("/audio/speech", response_class=StreamingResponse)
 async def synthesize_kannada(request: KannadaSynthesizeRequest):
     if not tts_manager.model:
-        raise HTTPException(status_code=503, detail="TTS model not loaded")
-    kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
-    if not request.text.strip():
-        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
+        raise HTTPException(status_code=503, detail="TTS model not loaded. Please load models via /v1/load_all_models.")
 
-    audio_buffer = synthesize_speech(
+    kannada_example = next((ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)"), None)
+    if not kannada_example:
+        raise HTTPException(status_code=500, detail="Reference audio configuration not found.")
+
+    logger.info(f"Received speech synthesis request for text: {request.text[:50]}...")
+    audio_buffer = await synthesize_speech(
         tts_manager,
         text=request.text,
         ref_audio_name="KAN_F (Happy)",
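The reworked /audio/speech endpoint still takes a JSON body with a single text field and streams a WAV file back. A hedged client sketch with requests (the base URL is an assumption; this commit does not pin a host or port):

import requests

BASE_URL = "http://localhost:8000"  # assumed host/port

resp = requests.post(
    f"{BASE_URL}/audio/speech",
    json={"text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."},
    timeout=300,
)
resp.raise_for_status()
with open("synthesized_kannada_speech.wav", "wb") as f:
    f.write(resp.content)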
@@ -798,6 +802,118 @@ async def synthesize_kannada(request: KannadaSynthesizeRequest):
         headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
     )
 
+@app.post("/transcribe/", response_model=TranscriptionResponse)
+async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
+    if not asr_manager.model:
+        raise HTTPException(status_code=503, detail="ASR model not loaded. Please load models via /v1/load_all_models.")
+
+    audio_data = await file.read()
+    if not audio_data:
+        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
+    if len(audio_data) > 10 * 1024 * 1024:  # 10MB limit
+        raise HTTPException(status_code=400, detail="Audio file exceeds 10MB limit")
+
+    logger.info(f"Transcribing audio file: {file.filename} in language: {language}")
+    try:
+        wav, sr = torchaudio.load(io.BytesIO(audio_data))
+        wav = torch.mean(wav, dim=0, keepdim=True)
+        target_sample_rate = 16000
+        if sr != target_sample_rate:
+            logger.info(f"Resampling audio from {sr}Hz to {target_sample_rate}Hz")
+            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
+            wav = resampler(wav)
+        transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+        logger.info(f"Transcription completed: {transcription_rnnt[:50]}...")
+        return TranscriptionResponse(text=transcription_rnnt)
+    except Exception as e:
+        logger.error(f"Error in transcription: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
+
+async def transcribe_step(audio_data: bytes, language: str) -> str:
+    if not asr_manager.model:
+        raise HTTPException(status_code=503, detail="ASR model not loaded")
+    wav, sr = torchaudio.load(io.BytesIO(audio_data))
+    wav = torch.mean(wav, dim=0, keepdim=True)
+    target_sample_rate = 16000
+    if sr != target_sample_rate:
+        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
+        wav = resampler(wav)
+    return asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+
+async def synthesize_step(text: str) -> io.BytesIO:
+    kannada_example = next((ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)"), None)
+    if not kannada_example:
+        raise HTTPException(status_code=500, detail="Reference audio configuration not found")
+    return await synthesize_speech(tts_manager, text, "KAN_F (Happy)", kannada_example["ref_text"])
+
+@app.post("/v1/speech_to_speech", response_class=StreamingResponse)
+async def speech_to_speech(
+    request: Request,
+    file: UploadFile = File(...),
+    language: str = Query(..., enum=list(asr_manager.model_language.keys())),
+):
+    if not tts_manager.model or not asr_manager.model:
+        raise HTTPException(status_code=503, detail="TTS or ASR model not loaded. Please load models via /v1/load_all_models.")
+
+    audio_data = await file.read()
+    if not audio_data:
+        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
+    if len(audio_data) > 10 * 1024 * 1024:  # 10MB limit
+        raise HTTPException(status_code=400, detail="Audio file exceeds 10MB limit")
+
+    logger.info(f"Processing speech-to-speech for file: {file.filename} in language: {language}")
+    try:
+        # Step 1: Transcribe
+        transcription = await transcribe_step(audio_data, language)
+        logger.info(f"Transcribed text: {transcription[:50]}...")
+
+        # Step 2: Process with LLM
+        chat_request = ChatRequest(
+            prompt=transcription,
+            src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
+            tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
+        )
+        processed_text = await chat(request, chat_request)
+        logger.info(f"Processed text: {processed_text.response[:50]}...")
+
+        # Step 3: Synthesize
+        audio_buffer = await synthesize_step(processed_text.response)
+        logger.info("Speech-to-speech processing completed")
+
+        return StreamingResponse(
+            audio_buffer,
+            media_type="audio/wav",
+            headers={"Content-Disposition": "attachment; filename=speech_to_speech_output.wav"}
+        )
+    except Exception as e:
+        logger.error(f"Error in speech-to-speech pipeline: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Speech-to-speech failed: {str(e)}")
+
+@app.get("/v1/health")
+async def health_check():
+    status = {
+        "status": "healthy",
+        "model": settings.llm_model_name,
+        "llm_loaded": llm_manager.is_loaded,
+        "tts_loaded": bool(tts_manager.model),
+        "asr_loaded": bool(asr_manager.model),
+        "translation_models": list(model_manager.models.keys()),
+        "device": device,
+        "cuda_available": cuda_available,
+        "cuda_version": cuda_version if cuda_available else "N/A"
+    }
+    logger.info("Health check requested")
+    return status
+
+@app.get("/")
+async def home():
+    logger.info("Root endpoint accessed, redirecting to docs")
+    return JSONResponse(
+        content={"message": "Welcome to Dhwani API! Redirecting to documentation..."},
+        headers={"Location": "/docs"},
+        status_code=302
+    )
+
 @app.post("/translate", response_model=TranslationResponse)
 async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
     if not request.sentences:
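Both new upload endpoints, /transcribe/ and /v1/speech_to_speech, take a multipart file plus a language query parameter and reject empty uploads or files over 10 MB. A hedged client sketch (base URL, file name, and the "kannada" language value are assumptions; language must be a key of asr_manager.model_language):

import requests

BASE_URL = "http://localhost:8000"  # assumed host/port

# Transcription only
with open("question.wav", "rb") as f:
    r = requests.post(
        f"{BASE_URL}/transcribe/",
        params={"language": "kannada"},
        files={"file": ("question.wav", f, "audio/wav")},
        timeout=300,
    )
r.raise_for_status()
print(r.json()["text"])  # TranscriptionResponse.text

# Full speech-to-speech pipeline: ASR -> LLM -> TTS
with open("question.wav", "rb") as f:
    r = requests.post(
        f"{BASE_URL}/v1/speech_to_speech",
        params={"language": "kannada"},
        files={"file": ("question.wav", f, "audio/wav")},
        timeout=600,
    )
r.raise_for_status()
with open("speech_to_speech_output.wav", "wb") as out:
    out.write(r.content)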
@@ -832,14 +948,6 @@ async def translate(request: TranslationRequest, translate_manager: TranslateMan
     translations = ip.postprocess_batch(generated_tokens, lang=request.tgt_lang)
     return TranslationResponse(translations=translations)
 
-@app.get("/v1/health")
-async def health_check():
-    return {"status": "healthy", "model": settings.llm_model_name}
-
-@app.get("/")
-async def home():
-    return RedirectResponse(url="/docs")
-
 @app.post("/v1/translate", response_model=TranslationResponse)
 async def translate_endpoint(request: TranslationRequest):
     logger.info(f"Received translation request: {request.dict()}")
@@ -855,46 +963,6 @@ async def translate_endpoint(request: TranslationRequest):
         logger.error(f"Unexpected error during translation: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
 
-@app.post("/transcribe/", response_model=TranscriptionResponse)
-async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
-    if not asr_manager.model:
-        raise HTTPException(status_code=503, detail="ASR model not loaded")
-    try:
-        wav, sr = torchaudio.load(file.file)
-        wav = torch.mean(wav, dim=0, keepdim=True)
-        target_sample_rate = 16000
-        if sr != target_sample_rate:
-            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
-            wav = resampler(wav)
-        transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
-        return TranscriptionResponse(text=transcription_rnnt)
-    except Exception as e:
-        logger.error(f"Error in transcription: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
-
-@app.post("/v1/speech_to_speech")
-async def speech_to_speech(
-    request: Request,
-    file: UploadFile = File(...),
-    language: str = Query(..., enum=list(asr_manager.model_language.keys())),
-) -> StreamingResponse:
-    if not tts_manager.model:
-        raise HTTPException(status_code=503, detail="TTS model not loaded")
-    transcription = await transcribe_audio(file, language)
-    logger.info(f"Transcribed text: {transcription.text}")
-
-    chat_request = ChatRequest(
-        prompt=transcription.text,
-        src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
-        tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
-    )
-    processed_text = await chat(request, chat_request)
-    logger.info(f"Processed text: {processed_text.response}")
-
-    voice_request = KannadaSynthesizeRequest(text=processed_text.response)
-    audio_response = await synthesize_kannada(voice_request)
-    return audio_response
-
 LANGUAGE_TO_SCRIPT = {
     "kannada": "kan_Knda"
 }
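With the whole diff applied, /v1/health reports per-component load state instead of a bare status string, so it can double as a readiness probe. A quick hedged check against an assumed local server:

import requests

r = requests.get("http://localhost:8000/v1/health", timeout=10)  # assumed host/port
r.raise_for_status()
info = r.json()
print(info["status"], info["llm_loaded"], info["tts_loaded"], info["asr_loaded"])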
 