sachin committed on
Commit
abca105
·
1 Parent(s): fd114d7
Files changed (1) hide show
  1. src/server/main.py +56 -111
src/server/main.py CHANGED
@@ -21,27 +21,14 @@ from IndicTransToolkit import IndicProcessor
21
  from logging_config import logger
22
  from tts_config import SPEED, ResponseFormat, config as tts_config
23
  from gemma_llm import LLMManager
24
- # from auth import get_api_key, settings as auth_settings
25
-
26
 
27
  import time
28
  from contextlib import asynccontextmanager
29
- from typing import Annotated, Any, OrderedDict, List
30
  import zipfile
31
  import soundfile as sf
32
- import torch
33
- from fastapi import Body, FastAPI, HTTPException, Response
34
- from parler_tts import ParlerTTSForConditionalGeneration
35
- from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
36
  import numpy as np
37
  from config import SPEED, ResponseFormat, config
38
- from logger import logger
39
- import uvicorn
40
- import argparse
41
- from fastapi.responses import RedirectResponse, StreamingResponse
42
- import io
43
- import os
44
- import logging
45
 
46
  # Device setup
47
  if torch.cuda.is_available():
@@ -89,29 +76,13 @@ class TTSModelManager:
89
  tokenizer = AutoTokenizer.from_pretrained(model_name)
90
  description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
91
 
92
- # Set pad tokens
93
  if tokenizer.pad_token is None:
94
  tokenizer.pad_token = tokenizer.eos_token
95
  if description_tokenizer.pad_token is None:
96
  description_tokenizer.pad_token = description_tokenizer.eos_token
97
-
98
 
99
- # TODO - temporary disable -torch.compile
100
- '''
101
- # Update model configuration
102
- model.config.pad_token_id = tokenizer.pad_token_id
103
- # Update for deprecation: use max_batch_size instead of batch_size
104
- if hasattr(model.generation_config.cache_config, 'max_batch_size'):
105
- model.generation_config.cache_config.max_batch_size = 1
106
- model.generation_config.cache_implementation = "static"
107
- '''
108
- # Compile the model
109
- compile_mode = "default"
110
- #compile_mode = "reduce-overhead"
111
-
112
- model.forward = torch.compile(model.forward, mode=compile_mode)
113
 
114
- # Warmup
115
  warmup_inputs = tokenizer("Warmup text for compilation",
116
  return_tensors="pt",
117
  padding="max_length",
@@ -124,8 +95,7 @@ class TTSModelManager:
124
  "prompt_attention_mask": warmup_inputs["attention_mask"],
125
  }
126
 
127
- n_steps = 1 if compile_mode == "default" else 2
128
- for _ in range(n_steps):
129
  _ = model.generate(**model_kwargs)
130
 
131
  logger.info(
@@ -152,16 +122,14 @@ async def lifespan(_: FastAPI):
152
  tts_model_manager.get_or_load_model(config.model)
153
  yield
154
 
155
- #app = FastAPI(lifespan=lifespan)
156
  app = FastAPI(
157
  title="Dhwani API",
158
- description="AI Chat API supporting Indian languages",
159
  version="1.0.0",
160
  redirect_slashes=False,
161
  lifespan=lifespan
162
  )
163
 
164
-
165
  def chunk_text(text, chunk_size):
166
  words = text.split()
167
  chunks = []
@@ -197,7 +165,6 @@ async def generate_audio(
197
  padding="max_length",
198
  max_length=tts_model_manager.max_length).to(device)
199
 
200
- # Use the tensor fields directly instead of BatchEncoding object
201
  input_ids = desc_inputs["input_ids"]
202
  attention_mask = desc_inputs["attention_mask"]
203
  prompt_input_ids = prompt_inputs["input_ids"]
@@ -323,14 +290,23 @@ async def generate_audio_batch(
323
 
324
  return StreamingResponse(in_memory_zip, media_type="application/zip")
325
 
326
-
327
  # Supported language codes
328
  SUPPORTED_LANGUAGES = {
 
329
  "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
330
  "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
331
  "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
332
  "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
333
- "kan_Knda", "ory_Orya"
 
 
 
 
 
 
 
 
 
334
  }
335
 
336
  class Settings(BaseSettings):
@@ -352,7 +328,6 @@ class Settings(BaseSettings):
352
 
353
  settings = Settings()
354
 
355
-
356
  app.add_middleware(
357
  CORSMiddleware,
358
  allow_origins=["*"],
@@ -366,7 +341,6 @@ app.state.limiter = limiter
366
 
367
  llm_manager = LLMManager(settings.llm_model_name)
368
 
369
- # Translation Manager and Model Manager
370
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
371
 
372
  class TranslateManager:
@@ -382,7 +356,7 @@ class TranslateManager:
382
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
383
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
384
  else:
385
- raise ValueError("Invalid language combination: English to English translation is not supported.")
386
 
387
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
388
  model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -415,7 +389,7 @@ class ModelManager:
415
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
416
  key = 'indic_indic'
417
  else:
418
- raise ValueError("Invalid language combination: English to English translation is not supported.")
419
 
420
  if key not in self.models:
421
  if self.is_lazy_loading:
@@ -432,11 +406,10 @@ class ModelManager:
432
  ip = IndicProcessor(inference=True)
433
  model_manager = ModelManager()
434
 
435
- # Pydantic Models
436
  class ChatRequest(BaseModel):
437
  prompt: str
438
- src_lang: str = "kan_Knda" # Default to Kannada
439
- tgt_lang: str = "kan_Knda" # Default to Kannada
440
 
441
  @field_validator("prompt")
442
  def prompt_must_be_valid(cls, v):
@@ -461,11 +434,9 @@ class TranslationRequest(BaseModel):
461
  class TranslationResponse(BaseModel):
462
  translations: List[str]
463
 
464
- # Dependency to get TranslateManager
465
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
466
  return model_manager.get_model(src_lang, tgt_lang)
467
 
468
- # Internal Translation Endpoint
469
  @app.post("/translate", response_model=TranslationResponse)
470
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
471
  input_sentences = request.sentences
@@ -505,14 +476,12 @@ async def translate(request: TranslationRequest, translate_manager: TranslateMan
505
  translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
506
  return TranslationResponse(translations=translations)
507
 
508
- # Helper function to perform internal translation
509
  async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
510
  translate_manager = model_manager.get_model(src_lang, tgt_lang)
511
  request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
512
  response = await translate(request, translate_manager)
513
  return response.translations
514
 
515
- # API Endpoints
516
  @app.get("/v1/health")
517
  async def health_check():
518
  return {"status": "healthy", "model": settings.llm_model_name}
@@ -564,9 +533,14 @@ async def chat(request: Request, chat_request: ChatRequest):
564
  if not chat_request.prompt:
565
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
566
  logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
 
567
  try:
568
- # Translate prompt to English if src_lang is not English
569
- if chat_request.src_lang != "eng_Latn":
 
 
 
 
570
  translated_prompt = await perform_internal_translation(
571
  sentences=[chat_request.prompt],
572
  src_lang=chat_request.src_lang,
@@ -575,15 +549,16 @@ async def chat(request: Request, chat_request: ChatRequest):
575
  prompt_to_process = translated_prompt[0]
576
  logger.info(f"Translated prompt to English: {prompt_to_process}")
577
  else:
 
578
  prompt_to_process = chat_request.prompt
579
- logger.info("Prompt already in English, no translation needed")
580
 
581
- # Generate response in English
582
  response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
583
- logger.info(f"Generated English response: {response}")
584
 
585
- # Translate response to target language if tgt_lang is not English
586
- if chat_request.tgt_lang != "eng_Latn":
587
  translated_response = await perform_internal_translation(
588
  sentences=[response],
589
  src_lang="eng_Latn",
@@ -592,8 +567,9 @@ async def chat(request: Request, chat_request: ChatRequest):
592
  final_response = translated_response[0]
593
  logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
594
  else:
 
595
  final_response = response
596
- logger.info("Response kept in English, no translation needed")
597
 
598
  return ChatResponse(response=final_response)
599
  except Exception as e:
@@ -612,8 +588,10 @@ async def visual_query(
612
  if image.size == (0, 0):
613
  raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
614
 
615
- # Translate query to English if src_lang is not English
616
- if src_lang != "eng_Latn":
 
 
617
  translated_query = await perform_internal_translation(
618
  sentences=[query],
619
  src_lang=src_lang,
@@ -623,14 +601,12 @@ async def visual_query(
623
  logger.info(f"Translated query to English: {query_to_process}")
624
  else:
625
  query_to_process = query
626
- logger.info("Query already in English, no translation needed")
627
 
628
- # Generate response in English
629
  answer = await llm_manager.vision_query(image, query_to_process)
630
  logger.info(f"Generated English answer: {answer}")
631
 
632
- # Translate answer to target language if tgt_lang is not English
633
- if tgt_lang != "eng_Latn":
634
  translated_answer = await perform_internal_translation(
635
  sentences=[answer],
636
  src_lang="eng_Latn",
@@ -640,7 +616,7 @@ async def visual_query(
640
  logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
641
  else:
642
  final_answer = answer
643
- logger.info("Answer kept in English, no translation needed")
644
 
645
  return {"answer": final_answer}
646
  except Exception as e:
@@ -664,14 +640,16 @@ async def chat_v2(
664
  logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
665
 
666
  try:
 
 
 
667
  if image:
668
  image_data = await image.read()
669
  if not image_data:
670
  raise HTTPException(status_code=400, detail="Uploaded image is empty")
671
  img = Image.open(io.BytesIO(image_data))
672
 
673
- # Translate prompt to English if src_lang is not English
674
- if src_lang != "eng_Latn":
675
  translated_prompt = await perform_internal_translation(
676
  sentences=[prompt],
677
  src_lang=src_lang,
@@ -681,13 +659,12 @@ async def chat_v2(
681
  logger.info(f"Translated prompt to English: {prompt_to_process}")
682
  else:
683
  prompt_to_process = prompt
684
- logger.info("Prompt already in English, no translation needed")
685
 
686
  decoded = await llm_manager.chat_v2(img, prompt_to_process)
687
- logger.info(f"Generated English response: {decoded}")
688
 
689
- # Translate response to target language if tgt_lang is not English
690
- if tgt_lang != "eng_Latn":
691
  translated_response = await perform_internal_translation(
692
  sentences=[decoded],
693
  src_lang="eng_Latn",
@@ -697,10 +674,9 @@ async def chat_v2(
697
  logger.info(f"Translated response to {tgt_lang}: {final_response}")
698
  else:
699
  final_response = decoded
700
- logger.info("Response kept in English, no translation needed")
701
  else:
702
- # Translate prompt to English if src_lang is not English
703
- if src_lang != "eng_Latn":
704
  translated_prompt = await perform_internal_translation(
705
  sentences=[prompt],
706
  src_lang=src_lang,
@@ -710,13 +686,12 @@ async def chat_v2(
710
  logger.info(f"Translated prompt to English: {prompt_to_process}")
711
  else:
712
  prompt_to_process = prompt
713
- logger.info("Prompt already in English, no translation needed")
714
 
715
  decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
716
- logger.info(f"Generated English response: {decoded}")
717
 
718
- # Translate response to target language if tgt_lang is not English
719
- if tgt_lang != "eng_Latn":
720
  translated_response = await perform_internal_translation(
721
  sentences=[decoded],
722
  src_lang="eng_Latn",
@@ -726,7 +701,7 @@ async def chat_v2(
726
  logger.info(f"Translated response to {tgt_lang}: {final_response}")
727
  else:
728
  final_response = decoded
729
- logger.info("Response kept in English, no translation needed")
730
 
731
  return ChatResponse(response=final_response)
732
  except Exception as e:
@@ -736,7 +711,6 @@ async def chat_v2(
736
  class TranscriptionResponse(BaseModel):
737
  text: str
738
 
739
-
740
  class ASRModelManager:
741
  def __init__(self, device_type="cuda"):
742
  self.device_type = device_type
@@ -748,54 +722,25 @@ class ASRModelManager:
748
  "telugu": "te", "urdu": "ur"
749
  }
750
 
751
-
752
- from fastapi import FastAPI, UploadFile
753
- import torch
754
- import torchaudio
755
- from transformers import AutoModel
756
- import argparse
757
- import uvicorn
758
- from pydantic import BaseModel
759
- from pydub import AudioSegment
760
- from fastapi import FastAPI, File, UploadFile, HTTPException, Query
761
- from fastapi.responses import RedirectResponse, JSONResponse
762
- from typing import List
763
-
764
- # Load the model
765
  model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
766
-
767
- asr_manager = ASRModelManager() # Load Kannada, Hindi, Tamil, Telugu, Malayalam
768
-
769
-
770
- #asr_manager = ASRModelManager(device_type="")
771
 
772
  @app.post("/transcribe/", response_model=TranscriptionResponse)
773
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
774
- # Load the uploaded audio file
775
  wav, sr = torchaudio.load(file.file)
776
  wav = torch.mean(wav, dim=0, keepdim=True)
777
 
778
- # Resample if necessary
779
- target_sample_rate = 16000 # Expected sample rate
780
  if sr != target_sample_rate:
781
  resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
782
  wav = resampler(wav)
783
 
784
- # Perform ASR with CTC decoding
785
- #transcription_ctc = model(wav, "kn", "ctc")
786
-
787
- # Perform ASR with RNNT decoding
788
  transcription_rnnt = model(wav, "kn", "rnnt")
789
-
790
  return JSONResponse(content={"text": transcription_rnnt})
791
 
792
-
793
-
794
  class BatchTranscriptionResponse(BaseModel):
795
  transcriptions: List[str]
796
 
797
-
798
-
799
  if __name__ == "__main__":
800
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
801
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
 
21
  from logging_config import logger
22
  from tts_config import SPEED, ResponseFormat, config as tts_config
23
  from gemma_llm import LLMManager
 
 
24
 
25
  import time
26
  from contextlib import asynccontextmanager
27
+ from typing import Annotated, Any, OrderedDict
28
  import zipfile
29
  import soundfile as sf
 
 
 
 
30
  import numpy as np
31
  from config import SPEED, ResponseFormat, config
 
 
 
 
 
 
 
32
 
33
  # Device setup
34
  if torch.cuda.is_available():
 
76
  tokenizer = AutoTokenizer.from_pretrained(model_name)
77
  description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
78
 
 
79
  if tokenizer.pad_token is None:
80
  tokenizer.pad_token = tokenizer.eos_token
81
  if description_tokenizer.pad_token is None:
82
  description_tokenizer.pad_token = description_tokenizer.eos_token
 
83
 
84
+ model.forward = torch.compile(model.forward, mode="default")
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
 
86
  warmup_inputs = tokenizer("Warmup text for compilation",
87
  return_tensors="pt",
88
  padding="max_length",
 
95
  "prompt_attention_mask": warmup_inputs["attention_mask"],
96
  }
97
 
98
+ for _ in range(1):
 
99
  _ = model.generate(**model_kwargs)
100
 
101
  logger.info(
 
122
  tts_model_manager.get_or_load_model(config.model)
123
  yield
124
 
 
125
  app = FastAPI(
126
  title="Dhwani API",
127
+ description="AI Chat API supporting multiple languages",
128
  version="1.0.0",
129
  redirect_slashes=False,
130
  lifespan=lifespan
131
  )
132
 
 
133
  def chunk_text(text, chunk_size):
134
  words = text.split()
135
  chunks = []
 
165
  padding="max_length",
166
  max_length=tts_model_manager.max_length).to(device)
167
 
 
168
  input_ids = desc_inputs["input_ids"]
169
  attention_mask = desc_inputs["attention_mask"]
170
  prompt_input_ids = prompt_inputs["input_ids"]
 
290
 
291
  return StreamingResponse(in_memory_zip, media_type="application/zip")
292
 
 
293
  # Supported language codes
294
  SUPPORTED_LANGUAGES = {
295
+ # Indian languages
296
  "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
297
  "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
298
  "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
299
  "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
300
+ "kan_Knda", "ory_Orya",
301
+ # European languages
302
+ "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
303
+ "por_Latn", "rus_Cyrl", "pol_Latn"
304
+ }
305
+
306
+ # Define European languages for direct processing
307
+ EUROPEAN_LANGUAGES = {
308
+ "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
309
+ "por_Latn", "rus_Cyrl", "pol_Latn"
310
  }
311
 
312
  class Settings(BaseSettings):
 
328
 
329
  settings = Settings()
330
 
 
331
  app.add_middleware(
332
  CORSMiddleware,
333
  allow_origins=["*"],
 
341
 
342
  llm_manager = LLMManager(settings.llm_model_name)
343
 
 
344
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
345
 
346
  class TranslateManager:
 
356
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
357
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
358
  else:
359
+ raise ValueError("Invalid language combination: English to English or European languages not supported here.")
360
 
361
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
362
  model = AutoModelForSeq2SeqLM.from_pretrained(
 
389
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
390
  key = 'indic_indic'
391
  else:
392
+ raise ValueError("Invalid language combination for translation.")
393
 
394
  if key not in self.models:
395
  if self.is_lazy_loading:
 
406
  ip = IndicProcessor(inference=True)
407
  model_manager = ModelManager()
408
 
 
409
  class ChatRequest(BaseModel):
410
  prompt: str
411
+ src_lang: str = "kan_Knda"
412
+ tgt_lang: str = "kan_Knda"
413
 
414
  @field_validator("prompt")
415
  def prompt_must_be_valid(cls, v):
 
434
  class TranslationResponse(BaseModel):
435
  translations: List[str]
436
 
 
437
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
438
  return model_manager.get_model(src_lang, tgt_lang)
439
 
 
440
  @app.post("/translate", response_model=TranslationResponse)
441
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
442
  input_sentences = request.sentences
 
476
  translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
477
  return TranslationResponse(translations=translations)
478
 
 
479
  async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
480
  translate_manager = model_manager.get_model(src_lang, tgt_lang)
481
  request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
482
  response = await translate(request, translate_manager)
483
  return response.translations
484
 
 
485
  @app.get("/v1/health")
486
  async def health_check():
487
  return {"status": "healthy", "model": settings.llm_model_name}
 
533
  if not chat_request.prompt:
534
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
535
  logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
536
+
537
  try:
538
+ # Determine if the language requires translation (Indian languages only)
539
+ is_indian_language = chat_request.src_lang not in EUROPEAN_LANGUAGES and chat_request.src_lang != "eng_Latn"
540
+ is_target_indian = chat_request.tgt_lang not in EUROPEAN_LANGUAGES and chat_request.tgt_lang != "eng_Latn"
541
+
542
+ if is_indian_language:
543
+ # Translate prompt to English for Indian languages
544
  translated_prompt = await perform_internal_translation(
545
  sentences=[chat_request.prompt],
546
  src_lang=chat_request.src_lang,
 
549
  prompt_to_process = translated_prompt[0]
550
  logger.info(f"Translated prompt to English: {prompt_to_process}")
551
  else:
552
+ # Use prompt directly for English and European languages
553
  prompt_to_process = chat_request.prompt
554
+ logger.info("Prompt in English or European language, no translation needed")
555
 
556
+ # Generate response directly with the LLM
557
  response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
558
+ logger.info(f"Generated response: {response}")
559
 
560
+ if is_target_indian and chat_request.tgt_lang != "eng_Latn":
561
+ # Translate response to target Indian language
562
  translated_response = await perform_internal_translation(
563
  sentences=[response],
564
  src_lang="eng_Latn",
 
567
  final_response = translated_response[0]
568
  logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
569
  else:
570
+ # Keep response as-is for English and European languages
571
  final_response = response
572
+ logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
573
 
574
  return ChatResponse(response=final_response)
575
  except Exception as e:
 
588
  if image.size == (0, 0):
589
  raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
590
 
591
+ is_indian_language = src_lang not in EUROPEAN_LANGUAGES and src_lang != "eng_Latn"
592
+ is_target_indian = tgt_lang not in EUROPEAN_LANGUAGES and tgt_lang != "eng_Latn"
593
+
594
+ if is_indian_language:
595
  translated_query = await perform_internal_translation(
596
  sentences=[query],
597
  src_lang=src_lang,
 
601
  logger.info(f"Translated query to English: {query_to_process}")
602
  else:
603
  query_to_process = query
604
+ logger.info("Query in English or European language, no translation needed")
605
 
 
606
  answer = await llm_manager.vision_query(image, query_to_process)
607
  logger.info(f"Generated English answer: {answer}")
608
 
609
+ if is_target_indian and tgt_lang != "eng_Latn":
 
610
  translated_answer = await perform_internal_translation(
611
  sentences=[answer],
612
  src_lang="eng_Latn",
 
616
  logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
617
  else:
618
  final_answer = answer
619
+ logger.info(f"Answer in {tgt_lang}, no translation needed")
620
 
621
  return {"answer": final_answer}
622
  except Exception as e:
 
640
  logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
641
 
642
  try:
643
+ is_indian_language = src_lang not in EUROPEAN_LANGUAGES and src_lang != "eng_Latn"
644
+ is_target_indian = tgt_lang not in EUROPEAN_LANGUAGES and tgt_lang != "eng_Latn"
645
+
646
  if image:
647
  image_data = await image.read()
648
  if not image_data:
649
  raise HTTPException(status_code=400, detail="Uploaded image is empty")
650
  img = Image.open(io.BytesIO(image_data))
651
 
652
+ if is_indian_language:
 
653
  translated_prompt = await perform_internal_translation(
654
  sentences=[prompt],
655
  src_lang=src_lang,
 
659
  logger.info(f"Translated prompt to English: {prompt_to_process}")
660
  else:
661
  prompt_to_process = prompt
662
+ logger.info("Prompt in English or European language, no translation needed")
663
 
664
  decoded = await llm_manager.chat_v2(img, prompt_to_process)
665
+ logger.info(f"Generated response: {decoded}")
666
 
667
+ if is_target_indian and tgt_lang != "eng_Latn":
 
668
  translated_response = await perform_internal_translation(
669
  sentences=[decoded],
670
  src_lang="eng_Latn",
 
674
  logger.info(f"Translated response to {tgt_lang}: {final_response}")
675
  else:
676
  final_response = decoded
677
+ logger.info(f"Response in {tgt_lang}, no translation needed")
678
  else:
679
+ if is_indian_language:
 
680
  translated_prompt = await perform_internal_translation(
681
  sentences=[prompt],
682
  src_lang=src_lang,
 
686
  logger.info(f"Translated prompt to English: {prompt_to_process}")
687
  else:
688
  prompt_to_process = prompt
689
+ logger.info("Prompt in English or European language, no translation needed")
690
 
691
  decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
692
+ logger.info(f"Generated response: {decoded}")
693
 
694
+ if is_target_indian and tgt_lang != "eng_Latn":
 
695
  translated_response = await perform_internal_translation(
696
  sentences=[decoded],
697
  src_lang="eng_Latn",
 
701
  logger.info(f"Translated response to {tgt_lang}: {final_response}")
702
  else:
703
  final_response = decoded
704
+ logger.info(f"Response in {tgt_lang}, no translation needed")
705
 
706
  return ChatResponse(response=final_response)
707
  except Exception as e:
 
711
  class TranscriptionResponse(BaseModel):
712
  text: str
713
 
 
714
  class ASRModelManager:
715
  def __init__(self, device_type="cuda"):
716
  self.device_type = device_type
 
722
  "telugu": "te", "urdu": "ur"
723
  }
724
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
725
  model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
726
+ asr_manager = ASRModelManager()
 
 
 
 
727
 
728
  @app.post("/transcribe/", response_model=TranscriptionResponse)
729
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
 
730
  wav, sr = torchaudio.load(file.file)
731
  wav = torch.mean(wav, dim=0, keepdim=True)
732
 
733
+ target_sample_rate = 16000
 
734
  if sr != target_sample_rate:
735
  resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
736
  wav = resampler(wav)
737
 
 
 
 
 
738
  transcription_rnnt = model(wav, "kn", "rnnt")
 
739
  return JSONResponse(content={"text": transcription_rnnt})
740
 
 
 
741
  class BatchTranscriptionResponse(BaseModel):
742
  transcriptions: List[str]
743
 
 
 
744
  if __name__ == "__main__":
745
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
746
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")