sachin committed on
Commit
ef20fe6
·
1 Parent(s): 6aa9f3a
Files changed (1) hide show
  1. src/server/main.py +23 -45
src/server/main.py CHANGED
@@ -23,7 +23,6 @@ from tts_config import SPEED, ResponseFormat, config as tts_config
23
  from gemma_llm import LLMManager
24
  # from auth import get_api_key, settings as auth_settings
25
 
26
-
27
  import time
28
  from contextlib import asynccontextmanager
29
  from typing import Annotated, Any, OrderedDict, List
@@ -95,7 +94,6 @@ class TTSModelManager:
95
  if description_tokenizer.pad_token is None:
96
  description_tokenizer.pad_token = description_tokenizer.eos_token
97
 
98
-
99
  # TODO - temporary disable -torch.compile
100
  '''
101
  # Update model configuration
@@ -152,7 +150,6 @@ async def lifespan(_: FastAPI):
152
  tts_model_manager.get_or_load_model(config.model)
153
  yield
154
 
155
- #app = FastAPI(lifespan=lifespan)
156
  app = FastAPI(
157
  title="Dhwani API",
158
  description="AI Chat API supporting Indian languages",
@@ -161,7 +158,6 @@ app = FastAPI(
161
  lifespan=lifespan
162
  )
163
 
164
-
165
  def chunk_text(text, chunk_size):
166
  words = text.split()
167
  chunks = []
@@ -197,7 +193,6 @@ async def generate_audio(
197
  padding="max_length",
198
  max_length=tts_model_manager.max_length).to(device)
199
 
200
- # Use the tensor fields directly instead of BatchEncoding object
201
  input_ids = desc_inputs["input_ids"]
202
  attention_mask = desc_inputs["attention_mask"]
203
  prompt_input_ids = prompt_inputs["input_ids"]
@@ -323,7 +318,6 @@ async def generate_audio_batch(
323
 
324
  return StreamingResponse(in_memory_zip, media_type="application/zip")
325
 
326
-
327
  # Supported language codes
328
  SUPPORTED_LANGUAGES = {
329
  "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
@@ -354,7 +348,6 @@ class Settings(BaseSettings):
354
 
355
  settings = Settings()
356
 
357
-
358
  app.add_middleware(
359
  CORSMiddleware,
360
  allow_origins=["*"],
@@ -543,7 +536,7 @@ async def load_all_models():
543
  return {"status": "success", "message": "All models loaded"}
544
  except Exception as e:
545
  logger.error(f"Error loading models: {str(e)}")
546
- raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")
547
 
548
  @app.post("/v1/translate", response_model=TranslationResponse)
549
  async def translate_endpoint(request: TranslationRequest):
@@ -567,13 +560,10 @@ async def chat(request: Request, chat_request: ChatRequest):
567
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
568
  logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
569
 
570
- # Define European languages that gemma-3-4b-it can handle natively
571
  EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
572
 
573
  try:
574
- # Check if the source language is Indian (requires translation) or European/English (direct processing)
575
  if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
576
- # Translate Indian language prompt to English
577
  translated_prompt = await perform_internal_translation(
578
  sentences=[chat_request.prompt],
579
  src_lang=chat_request.src_lang,
@@ -582,17 +572,13 @@ async def chat(request: Request, chat_request: ChatRequest):
582
  prompt_to_process = translated_prompt[0]
583
  logger.info(f"Translated prompt to English: {prompt_to_process}")
584
  else:
585
- # Use prompt directly for English and European languages
586
  prompt_to_process = chat_request.prompt
587
  logger.info("Prompt in English or European language, no translation needed")
588
 
589
- # Generate response with the LLM (assumed to handle multilingual input natively)
590
  response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
591
  logger.info(f"Generated response: {response}")
592
 
593
- # Check if the target language is Indian (requires translation) or European/English (direct output)
594
  if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
595
- # Translate response to Indian target language
596
  translated_response = await perform_internal_translation(
597
  sentences=[response],
598
  src_lang="eng_Latn",
@@ -601,7 +587,6 @@ async def chat(request: Request, chat_request: ChatRequest):
601
  final_response = translated_response[0]
602
  logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
603
  else:
604
- # Keep response as-is for English and European languages
605
  final_response = response
606
  logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
607
 
@@ -622,7 +607,6 @@ async def visual_query(
622
  if image.size == (0, 0):
623
  raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
624
 
625
- # Translate query to English if src_lang is not English
626
  if src_lang != "eng_Latn":
627
  translated_query = await perform_internal_translation(
628
  sentences=[query],
@@ -635,11 +619,9 @@ async def visual_query(
635
  query_to_process = query
636
  logger.info("Query already in English, no translation needed")
637
 
638
- # Generate response in English
639
  answer = await llm_manager.vision_query(image, query_to_process)
640
  logger.info(f"Generated English answer: {answer}")
641
 
642
- # Translate answer to target language if tgt_lang is not English
643
  if tgt_lang != "eng_Latn":
644
  translated_answer = await perform_internal_translation(
645
  sentences=[answer],
@@ -680,7 +662,6 @@ async def chat_v2(
680
  raise HTTPException(status_code=400, detail="Uploaded image is empty")
681
  img = Image.open(io.BytesIO(image_data))
682
 
683
- # Translate prompt to English if src_lang is not English
684
  if src_lang != "eng_Latn":
685
  translated_prompt = await perform_internal_translation(
686
  sentences=[prompt],
@@ -696,7 +677,6 @@ async def chat_v2(
696
  decoded = await llm_manager.chat_v2(img, prompt_to_process)
697
  logger.info(f"Generated English response: {decoded}")
698
 
699
- # Translate response to target language if tgt_lang is not English
700
  if tgt_lang != "eng_Latn":
701
  translated_response = await perform_internal_translation(
702
  sentences=[decoded],
@@ -709,7 +689,6 @@ async def chat_v2(
709
  final_response = decoded
710
  logger.info("Response kept in English, no translation needed")
711
  else:
712
- # Translate prompt to English if src_lang is not English
713
  if src_lang != "eng_Latn":
714
  translated_prompt = await perform_internal_translation(
715
  sentences=[prompt],
@@ -725,7 +704,6 @@ async def chat_v2(
725
  decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
726
  logger.info(f"Generated English response: {decoded}")
727
 
728
- # Translate response to target language if tgt_lang is not English
729
  if tgt_lang != "eng_Latn":
730
  translated_response = await perform_internal_translation(
731
  sentences=[decoded],
@@ -746,7 +724,6 @@ async def chat_v2(
746
  class TranscriptionResponse(BaseModel):
747
  text: str
748
 
749
-
750
  class ASRModelManager:
751
  def __init__(self, device_type="cuda"):
752
  self.device_type = device_type
@@ -758,7 +735,6 @@ class ASRModelManager:
758
  "telugu": "te", "urdu": "ur"
759
  }
760
 
761
-
762
  from fastapi import FastAPI, UploadFile
763
  import torch
764
  import torchaudio
@@ -774,27 +750,30 @@ from typing import List
774
  # Load the model
775
  model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
776
 
777
- asr_manager = ASRModelManager() # Load Kannada, Hindi, Tamil, Telugu, Malayalam
778
 
779
-
780
- #asr_manager = ASRModelManager(device_type="")
 
 
 
 
 
781
 
782
  @app.post("/transcribe/", response_model=TranscriptionResponse)
783
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
784
- # Load the uploaded audio file
785
- wav, sr = torchaudio.load(file.file)
786
- wav = torch.mean(wav, dim=0, keepdim=True)
787
-
788
- # Resample if necessary
789
- target_sample_rate = 16000 # Expected sample rate
790
- if sr != target_sample_rate:
791
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
792
- wav = resampler(wav)
793
-
794
- # Perform ASR with RNNT decoding using the provided language
795
- transcription_rnnt = model(wav, asr_manager.model_language[language], "rnnt")
796
-
797
- return JSONResponse(content={"text": transcription_rnnt})
798
 
799
  @app.post("/v1/speech_to_speech")
800
  async def speech_to_speech(
@@ -809,8 +788,8 @@ async def speech_to_speech(
809
  # Step 2: Process text with chat endpoint
810
  chat_request = ChatRequest(
811
  prompt=transcription.text,
812
- src_lang="kn_Knda", # Assuming script for Indian languages
813
- tgt_lang="kn_Knda"
814
  )
815
  processed_text = await chat(Request(), chat_request)
816
  logger.info(f"Processed text: {processed_text.response}")
@@ -825,7 +804,6 @@ async def speech_to_speech(
825
  )
826
  return audio_response
827
 
828
-
829
  class BatchTranscriptionResponse(BaseModel):
830
  transcriptions: List[str]
831
 
 
23
  from gemma_llm import LLMManager
24
  # from auth import get_api_key, settings as auth_settings
25
 
 
26
  import time
27
  from contextlib import asynccontextmanager
28
  from typing import Annotated, Any, OrderedDict, List
 
94
  if description_tokenizer.pad_token is None:
95
  description_tokenizer.pad_token = description_tokenizer.eos_token
96
 
 
97
  # TODO - temporary disable -torch.compile
98
  '''
99
  # Update model configuration
 
150
  tts_model_manager.get_or_load_model(config.model)
151
  yield
152
 
 
153
  app = FastAPI(
154
  title="Dhwani API",
155
  description="AI Chat API supporting Indian languages",
 
158
  lifespan=lifespan
159
  )
160
 
 
161
  def chunk_text(text, chunk_size):
162
  words = text.split()
163
  chunks = []
 
193
  padding="max_length",
194
  max_length=tts_model_manager.max_length).to(device)
195
 
 
196
  input_ids = desc_inputs["input_ids"]
197
  attention_mask = desc_inputs["attention_mask"]
198
  prompt_input_ids = prompt_inputs["input_ids"]
 
318
 
319
  return StreamingResponse(in_memory_zip, media_type="application/zip")
320
 
 
321
  # Supported language codes
322
  SUPPORTED_LANGUAGES = {
323
  "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
 
348
 
349
  settings = Settings()
350
 
 
351
  app.add_middleware(
352
  CORSMiddleware,
353
  allow_origins=["*"],
 
536
  return {"status": "success", "message": "All models loaded"}
537
  except Exception as e:
538
  logger.error(f"Error loading models: {str(e)}")
539
+ raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")
540
 
541
  @app.post("/v1/translate", response_model=TranslationResponse)
542
  async def translate_endpoint(request: TranslationRequest):
 
560
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
561
  logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
562
 
 
563
  EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
564
 
565
  try:
 
566
  if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
 
567
  translated_prompt = await perform_internal_translation(
568
  sentences=[chat_request.prompt],
569
  src_lang=chat_request.src_lang,
 
572
  prompt_to_process = translated_prompt[0]
573
  logger.info(f"Translated prompt to English: {prompt_to_process}")
574
  else:
 
575
  prompt_to_process = chat_request.prompt
576
  logger.info("Prompt in English or European language, no translation needed")
577
 
 
578
  response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
579
  logger.info(f"Generated response: {response}")
580
 
 
581
  if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
 
582
  translated_response = await perform_internal_translation(
583
  sentences=[response],
584
  src_lang="eng_Latn",
 
587
  final_response = translated_response[0]
588
  logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
589
  else:
 
590
  final_response = response
591
  logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
592
 
 
607
  if image.size == (0, 0):
608
  raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
609
 
 
610
  if src_lang != "eng_Latn":
611
  translated_query = await perform_internal_translation(
612
  sentences=[query],
 
619
  query_to_process = query
620
  logger.info("Query already in English, no translation needed")
621
 
 
622
  answer = await llm_manager.vision_query(image, query_to_process)
623
  logger.info(f"Generated English answer: {answer}")
624
 
 
625
  if tgt_lang != "eng_Latn":
626
  translated_answer = await perform_internal_translation(
627
  sentences=[answer],
 
662
  raise HTTPException(status_code=400, detail="Uploaded image is empty")
663
  img = Image.open(io.BytesIO(image_data))
664
 
 
665
  if src_lang != "eng_Latn":
666
  translated_prompt = await perform_internal_translation(
667
  sentences=[prompt],
 
677
  decoded = await llm_manager.chat_v2(img, prompt_to_process)
678
  logger.info(f"Generated English response: {decoded}")
679
 
 
680
  if tgt_lang != "eng_Latn":
681
  translated_response = await perform_internal_translation(
682
  sentences=[decoded],
 
689
  final_response = decoded
690
  logger.info("Response kept in English, no translation needed")
691
  else:
 
692
  if src_lang != "eng_Latn":
693
  translated_prompt = await perform_internal_translation(
694
  sentences=[prompt],
 
704
  decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
705
  logger.info(f"Generated English response: {decoded}")
706
 
 
707
  if tgt_lang != "eng_Latn":
708
  translated_response = await perform_internal_translation(
709
  sentences=[decoded],
 
724
  class TranscriptionResponse(BaseModel):
725
  text: str
726
 
 
727
  class ASRModelManager:
728
  def __init__(self, device_type="cuda"):
729
  self.device_type = device_type
 
735
  "telugu": "te", "urdu": "ur"
736
  }
737
 
 
738
  from fastapi import FastAPI, UploadFile
739
  import torch
740
  import torchaudio
 
750
  # Load the model
751
  model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
752
 
753
+ asr_manager = ASRModelManager()
754
 
755
+ # Language to script mapping
756
+ LANGUAGE_TO_SCRIPT = {
757
+ "kannada": "kan_Knda", "hindi": "hin_Deva", "malayalam": "mal_Mlym", "tamil": "tam_Taml",
758
+ "telugu": "tel_Telu", "assamese": "asm_Beng", "bengali": "ben_Beng", "gujarati": "guj_Gujr",
759
+ "marathi": "mar_Deva", "odia": "ory_Orya", "punjabi": "pan_Guru", "urdu": "urd_Arab",
760
+ # Add more as needed
761
+ }
762
 
763
  @app.post("/transcribe/", response_model=TranscriptionResponse)
764
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
765
+ try:
766
+ wav, sr = torchaudio.load(file.file)
767
+ wav = torch.mean(wav, dim=0, keepdim=True)
768
+ target_sample_rate = 16000
769
+ if sr != target_sample_rate:
770
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
771
+ wav = resampler(wav)
772
+ transcription_rnnt = model(wav, asr_manager.model_language[language], "rnnt")
773
+ return TranscriptionResponse(text=transcription_rnnt)
774
+ except Exception as e:
775
+ logger.error(f"Error in transcription: {str(e)}")
776
+ raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
 
 
777
 
778
  @app.post("/v1/speech_to_speech")
779
  async def speech_to_speech(
 
788
  # Step 2: Process text with chat endpoint
789
  chat_request = ChatRequest(
790
  prompt=transcription.text,
791
+ src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"), # Dynamic script mapping
792
+ tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
793
  )
794
  processed_text = await chat(Request(), chat_request)
795
  logger.info(f"Processed text: {processed_text.response}")
 
804
  )
805
  return audio_response
806
 
 
807
  class BatchTranscriptionResponse(BaseModel):
808
  transcriptions: List[str]
809