sachin commited on
Commit
94b0142
·
1 Parent(s): abca105

fix-endpoint

Browse files
Files changed (1) hide show
  1. src/server/main.py +114 -51
src/server/main.py CHANGED
@@ -21,14 +21,27 @@ from IndicTransToolkit import IndicProcessor
21
  from logging_config import logger
22
  from tts_config import SPEED, ResponseFormat, config as tts_config
23
  from gemma_llm import LLMManager
 
 
24
 
25
  import time
26
  from contextlib import asynccontextmanager
27
- from typing import Annotated, Any, OrderedDict
28
  import zipfile
29
  import soundfile as sf
 
 
 
 
30
  import numpy as np
31
  from config import SPEED, ResponseFormat, config
 
 
 
 
 
 
 
32
 
33
  # Device setup
34
  if torch.cuda.is_available():
@@ -76,13 +89,29 @@ class TTSModelManager:
76
  tokenizer = AutoTokenizer.from_pretrained(model_name)
77
  description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
78
 
 
79
  if tokenizer.pad_token is None:
80
  tokenizer.pad_token = tokenizer.eos_token
81
  if description_tokenizer.pad_token is None:
82
  description_tokenizer.pad_token = description_tokenizer.eos_token
 
83
 
84
- model.forward = torch.compile(model.forward, mode="default")
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
 
86
  warmup_inputs = tokenizer("Warmup text for compilation",
87
  return_tensors="pt",
88
  padding="max_length",
@@ -95,7 +124,8 @@ class TTSModelManager:
95
  "prompt_attention_mask": warmup_inputs["attention_mask"],
96
  }
97
 
98
- for _ in range(1):
 
99
  _ = model.generate(**model_kwargs)
100
 
101
  logger.info(
@@ -122,14 +152,16 @@ async def lifespan(_: FastAPI):
122
  tts_model_manager.get_or_load_model(config.model)
123
  yield
124
 
 
125
  app = FastAPI(
126
  title="Dhwani API",
127
- description="AI Chat API supporting multiple languages",
128
  version="1.0.0",
129
  redirect_slashes=False,
130
  lifespan=lifespan
131
  )
132
 
 
133
  def chunk_text(text, chunk_size):
134
  words = text.split()
135
  chunks = []
@@ -165,6 +197,7 @@ async def generate_audio(
165
  padding="max_length",
166
  max_length=tts_model_manager.max_length).to(device)
167
 
 
168
  input_ids = desc_inputs["input_ids"]
169
  attention_mask = desc_inputs["attention_mask"]
170
  prompt_input_ids = prompt_inputs["input_ids"]
@@ -290,23 +323,14 @@ async def generate_audio_batch(
290
 
291
  return StreamingResponse(in_memory_zip, media_type="application/zip")
292
 
 
293
  # Supported language codes
294
  SUPPORTED_LANGUAGES = {
295
- # Indian languages
296
  "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
297
  "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
298
  "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
299
  "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
300
- "kan_Knda", "ory_Orya",
301
- # European languages
302
- "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
303
- "por_Latn", "rus_Cyrl", "pol_Latn"
304
- }
305
-
306
- # Define European languages for direct processing
307
- EUROPEAN_LANGUAGES = {
308
- "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
309
- "por_Latn", "rus_Cyrl", "pol_Latn"
310
  }
311
 
312
  class Settings(BaseSettings):
@@ -328,6 +352,7 @@ class Settings(BaseSettings):
328
 
329
  settings = Settings()
330
 
 
331
  app.add_middleware(
332
  CORSMiddleware,
333
  allow_origins=["*"],
@@ -341,6 +366,7 @@ app.state.limiter = limiter
341
 
342
  llm_manager = LLMManager(settings.llm_model_name)
343
 
 
344
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
345
 
346
  class TranslateManager:
@@ -356,7 +382,7 @@ class TranslateManager:
356
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
357
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
358
  else:
359
- raise ValueError("Invalid language combination: English to English or European languages not supported here.")
360
 
361
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
362
  model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -389,7 +415,7 @@ class ModelManager:
389
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
390
  key = 'indic_indic'
391
  else:
392
- raise ValueError("Invalid language combination for translation.")
393
 
394
  if key not in self.models:
395
  if self.is_lazy_loading:
@@ -406,10 +432,11 @@ class ModelManager:
406
  ip = IndicProcessor(inference=True)
407
  model_manager = ModelManager()
408
 
 
409
  class ChatRequest(BaseModel):
410
  prompt: str
411
- src_lang: str = "kan_Knda"
412
- tgt_lang: str = "kan_Knda"
413
 
414
  @field_validator("prompt")
415
  def prompt_must_be_valid(cls, v):
@@ -434,9 +461,11 @@ class TranslationRequest(BaseModel):
434
  class TranslationResponse(BaseModel):
435
  translations: List[str]
436
 
 
437
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
438
  return model_manager.get_model(src_lang, tgt_lang)
439
 
 
440
  @app.post("/translate", response_model=TranslationResponse)
441
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
442
  input_sentences = request.sentences
@@ -476,12 +505,14 @@ async def translate(request: TranslationRequest, translate_manager: TranslateMan
476
  translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
477
  return TranslationResponse(translations=translations)
478
 
 
479
  async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
480
  translate_manager = model_manager.get_model(src_lang, tgt_lang)
481
  request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
482
  response = await translate(request, translate_manager)
483
  return response.translations
484
 
 
485
  @app.get("/v1/health")
486
  async def health_check():
487
  return {"status": "healthy", "model": settings.llm_model_name}
@@ -533,14 +564,14 @@ async def chat(request: Request, chat_request: ChatRequest):
533
  if not chat_request.prompt:
534
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
535
  logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
536
-
 
 
 
537
  try:
538
- # Determine if the language requires translation (Indian languages only)
539
- is_indian_language = chat_request.src_lang not in EUROPEAN_LANGUAGES and chat_request.src_lang != "eng_Latn"
540
- is_target_indian = chat_request.tgt_lang not in EUROPEAN_LANGUAGES and chat_request.tgt_lang != "eng_Latn"
541
-
542
- if is_indian_language:
543
- # Translate prompt to English for Indian languages
544
  translated_prompt = await perform_internal_translation(
545
  sentences=[chat_request.prompt],
546
  src_lang=chat_request.src_lang,
@@ -553,12 +584,13 @@ async def chat(request: Request, chat_request: ChatRequest):
553
  prompt_to_process = chat_request.prompt
554
  logger.info("Prompt in English or European language, no translation needed")
555
 
556
- # Generate response directly with the LLM
557
  response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
558
  logger.info(f"Generated response: {response}")
559
 
560
- if is_target_indian and chat_request.tgt_lang != "eng_Latn":
561
- # Translate response to target Indian language
 
562
  translated_response = await perform_internal_translation(
563
  sentences=[response],
564
  src_lang="eng_Latn",
@@ -588,10 +620,8 @@ async def visual_query(
588
  if image.size == (0, 0):
589
  raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
590
 
591
- is_indian_language = src_lang not in EUROPEAN_LANGUAGES and src_lang != "eng_Latn"
592
- is_target_indian = tgt_lang not in EUROPEAN_LANGUAGES and tgt_lang != "eng_Latn"
593
-
594
- if is_indian_language:
595
  translated_query = await perform_internal_translation(
596
  sentences=[query],
597
  src_lang=src_lang,
@@ -601,12 +631,14 @@ async def visual_query(
601
  logger.info(f"Translated query to English: {query_to_process}")
602
  else:
603
  query_to_process = query
604
- logger.info("Query in English or European language, no translation needed")
605
 
 
606
  answer = await llm_manager.vision_query(image, query_to_process)
607
  logger.info(f"Generated English answer: {answer}")
608
 
609
- if is_target_indian and tgt_lang != "eng_Latn":
 
610
  translated_answer = await perform_internal_translation(
611
  sentences=[answer],
612
  src_lang="eng_Latn",
@@ -616,7 +648,7 @@ async def visual_query(
616
  logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
617
  else:
618
  final_answer = answer
619
- logger.info(f"Answer in {tgt_lang}, no translation needed")
620
 
621
  return {"answer": final_answer}
622
  except Exception as e:
@@ -640,16 +672,14 @@ async def chat_v2(
640
  logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
641
 
642
  try:
643
- is_indian_language = src_lang not in EUROPEAN_LANGUAGES and src_lang != "eng_Latn"
644
- is_target_indian = tgt_lang not in EUROPEAN_LANGUAGES and tgt_lang != "eng_Latn"
645
-
646
  if image:
647
  image_data = await image.read()
648
  if not image_data:
649
  raise HTTPException(status_code=400, detail="Uploaded image is empty")
650
  img = Image.open(io.BytesIO(image_data))
651
 
652
- if is_indian_language:
 
653
  translated_prompt = await perform_internal_translation(
654
  sentences=[prompt],
655
  src_lang=src_lang,
@@ -659,12 +689,13 @@ async def chat_v2(
659
  logger.info(f"Translated prompt to English: {prompt_to_process}")
660
  else:
661
  prompt_to_process = prompt
662
- logger.info("Prompt in English or European language, no translation needed")
663
 
664
  decoded = await llm_manager.chat_v2(img, prompt_to_process)
665
- logger.info(f"Generated response: {decoded}")
666
 
667
- if is_target_indian and tgt_lang != "eng_Latn":
 
668
  translated_response = await perform_internal_translation(
669
  sentences=[decoded],
670
  src_lang="eng_Latn",
@@ -674,9 +705,10 @@ async def chat_v2(
674
  logger.info(f"Translated response to {tgt_lang}: {final_response}")
675
  else:
676
  final_response = decoded
677
- logger.info(f"Response in {tgt_lang}, no translation needed")
678
  else:
679
- if is_indian_language:
 
680
  translated_prompt = await perform_internal_translation(
681
  sentences=[prompt],
682
  src_lang=src_lang,
@@ -686,12 +718,13 @@ async def chat_v2(
686
  logger.info(f"Translated prompt to English: {prompt_to_process}")
687
  else:
688
  prompt_to_process = prompt
689
- logger.info("Prompt in English or European language, no translation needed")
690
 
691
  decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
692
- logger.info(f"Generated response: {decoded}")
693
 
694
- if is_target_indian and tgt_lang != "eng_Latn":
 
695
  translated_response = await perform_internal_translation(
696
  sentences=[decoded],
697
  src_lang="eng_Latn",
@@ -701,7 +734,7 @@ async def chat_v2(
701
  logger.info(f"Translated response to {tgt_lang}: {final_response}")
702
  else:
703
  final_response = decoded
704
- logger.info(f"Response in {tgt_lang}, no translation needed")
705
 
706
  return ChatResponse(response=final_response)
707
  except Exception as e:
@@ -711,6 +744,7 @@ async def chat_v2(
711
  class TranscriptionResponse(BaseModel):
712
  text: str
713
 
 
714
  class ASRModelManager:
715
  def __init__(self, device_type="cuda"):
716
  self.device_type = device_type
@@ -722,25 +756,54 @@ class ASRModelManager:
722
  "telugu": "te", "urdu": "ur"
723
  }
724
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
725
  model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
726
- asr_manager = ASRModelManager()
 
 
 
 
727
 
728
  @app.post("/transcribe/", response_model=TranscriptionResponse)
729
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
 
730
  wav, sr = torchaudio.load(file.file)
731
  wav = torch.mean(wav, dim=0, keepdim=True)
732
 
733
- target_sample_rate = 16000
 
734
  if sr != target_sample_rate:
735
  resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
736
  wav = resampler(wav)
737
 
 
 
 
 
738
  transcription_rnnt = model(wav, "kn", "rnnt")
 
739
  return JSONResponse(content={"text": transcription_rnnt})
740
 
 
 
741
  class BatchTranscriptionResponse(BaseModel):
742
  transcriptions: List[str]
743
 
 
 
744
  if __name__ == "__main__":
745
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
746
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
 
21
  from logging_config import logger
22
  from tts_config import SPEED, ResponseFormat, config as tts_config
23
  from gemma_llm import LLMManager
24
+ # from auth import get_api_key, settings as auth_settings
25
+
26
 
27
  import time
28
  from contextlib import asynccontextmanager
29
+ from typing import Annotated, Any, OrderedDict, List
30
  import zipfile
31
  import soundfile as sf
32
+ import torch
33
+ from fastapi import Body, FastAPI, HTTPException, Response
34
+ from parler_tts import ParlerTTSForConditionalGeneration
35
+ from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
36
  import numpy as np
37
  from config import SPEED, ResponseFormat, config
38
+ from logger import logger
39
+ import uvicorn
40
+ import argparse
41
+ from fastapi.responses import RedirectResponse, StreamingResponse
42
+ import io
43
+ import os
44
+ import logging
45
 
46
  # Device setup
47
  if torch.cuda.is_available():
 
89
  tokenizer = AutoTokenizer.from_pretrained(model_name)
90
  description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
91
 
92
+ # Set pad tokens
93
  if tokenizer.pad_token is None:
94
  tokenizer.pad_token = tokenizer.eos_token
95
  if description_tokenizer.pad_token is None:
96
  description_tokenizer.pad_token = description_tokenizer.eos_token
97
+
98
 
99
+ # TODO - temporary disable -torch.compile
100
+ '''
101
+ # Update model configuration
102
+ model.config.pad_token_id = tokenizer.pad_token_id
103
+ # Update for deprecation: use max_batch_size instead of batch_size
104
+ if hasattr(model.generation_config.cache_config, 'max_batch_size'):
105
+ model.generation_config.cache_config.max_batch_size = 1
106
+ model.generation_config.cache_implementation = "static"
107
+ '''
108
+ # Compile the model
109
+ compile_mode = "default"
110
+ #compile_mode = "reduce-overhead"
111
+
112
+ model.forward = torch.compile(model.forward, mode=compile_mode)
113
 
114
+ # Warmup
115
  warmup_inputs = tokenizer("Warmup text for compilation",
116
  return_tensors="pt",
117
  padding="max_length",
 
124
  "prompt_attention_mask": warmup_inputs["attention_mask"],
125
  }
126
 
127
+ n_steps = 1 if compile_mode == "default" else 2
128
+ for _ in range(n_steps):
129
  _ = model.generate(**model_kwargs)
130
 
131
  logger.info(
 
152
  tts_model_manager.get_or_load_model(config.model)
153
  yield
154
 
155
+ #app = FastAPI(lifespan=lifespan)
156
  app = FastAPI(
157
  title="Dhwani API",
158
+ description="AI Chat API supporting Indian languages",
159
  version="1.0.0",
160
  redirect_slashes=False,
161
  lifespan=lifespan
162
  )
163
 
164
+
165
  def chunk_text(text, chunk_size):
166
  words = text.split()
167
  chunks = []
 
197
  padding="max_length",
198
  max_length=tts_model_manager.max_length).to(device)
199
 
200
+ # Use the tensor fields directly instead of BatchEncoding object
201
  input_ids = desc_inputs["input_ids"]
202
  attention_mask = desc_inputs["attention_mask"]
203
  prompt_input_ids = prompt_inputs["input_ids"]
 
323
 
324
  return StreamingResponse(in_memory_zip, media_type="application/zip")
325
 
326
+
327
  # Supported language codes
328
  SUPPORTED_LANGUAGES = {
 
329
  "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
330
  "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
331
  "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
332
  "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
333
+ "kan_Knda", "ory_Orya"
 
 
 
 
 
 
 
 
 
334
  }
335
 
336
  class Settings(BaseSettings):
 
352
 
353
  settings = Settings()
354
 
355
+
356
  app.add_middleware(
357
  CORSMiddleware,
358
  allow_origins=["*"],
 
366
 
367
  llm_manager = LLMManager(settings.llm_model_name)
368
 
369
+ # Translation Manager and Model Manager
370
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
371
 
372
  class TranslateManager:
 
382
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
383
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
384
  else:
385
+ raise ValueError("Invalid language combination: English to English translation is not supported.")
386
 
387
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
388
  model = AutoModelForSeq2SeqLM.from_pretrained(
 
415
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
416
  key = 'indic_indic'
417
  else:
418
+ raise ValueError("Invalid language combination: English to English translation is not supported.")
419
 
420
  if key not in self.models:
421
  if self.is_lazy_loading:
 
432
  ip = IndicProcessor(inference=True)
433
  model_manager = ModelManager()
434
 
435
+ # Pydantic Models
436
  class ChatRequest(BaseModel):
437
  prompt: str
438
+ src_lang: str = "kan_Knda" # Default to Kannada
439
+ tgt_lang: str = "kan_Knda" # Default to Kannada
440
 
441
  @field_validator("prompt")
442
  def prompt_must_be_valid(cls, v):
 
461
  class TranslationResponse(BaseModel):
462
  translations: List[str]
463
 
464
+ # Dependency to get TranslateManager
465
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
466
  return model_manager.get_model(src_lang, tgt_lang)
467
 
468
+ # Internal Translation Endpoint
469
  @app.post("/translate", response_model=TranslationResponse)
470
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
471
  input_sentences = request.sentences
 
505
  translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
506
  return TranslationResponse(translations=translations)
507
 
508
+ # Helper function to perform internal translation
509
  async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
510
  translate_manager = model_manager.get_model(src_lang, tgt_lang)
511
  request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
512
  response = await translate(request, translate_manager)
513
  return response.translations
514
 
515
+ # API Endpoints
516
  @app.get("/v1/health")
517
  async def health_check():
518
  return {"status": "healthy", "model": settings.llm_model_name}
 
564
  if not chat_request.prompt:
565
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
566
  logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
567
+
568
+ # Define European languages that gemma-3-4b-it can handle natively
569
+ EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
570
+
571
  try:
572
+ # Check if the source language is Indian (requires translation) or European/English (direct processing)
573
+ if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
574
+ # Translate Indian language prompt to English
 
 
 
575
  translated_prompt = await perform_internal_translation(
576
  sentences=[chat_request.prompt],
577
  src_lang=chat_request.src_lang,
 
584
  prompt_to_process = chat_request.prompt
585
  logger.info("Prompt in English or European language, no translation needed")
586
 
587
+ # Generate response with the LLM (assumed to handle multilingual input natively)
588
  response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
589
  logger.info(f"Generated response: {response}")
590
 
591
+ # Check if the target language is Indian (requires translation) or European/English (direct output)
592
+ if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
593
+ # Translate response to Indian target language
594
  translated_response = await perform_internal_translation(
595
  sentences=[response],
596
  src_lang="eng_Latn",
 
620
  if image.size == (0, 0):
621
  raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
622
 
623
+ # Translate query to English if src_lang is not English
624
+ if src_lang != "eng_Latn":
 
 
625
  translated_query = await perform_internal_translation(
626
  sentences=[query],
627
  src_lang=src_lang,
 
631
  logger.info(f"Translated query to English: {query_to_process}")
632
  else:
633
  query_to_process = query
634
+ logger.info("Query already in English, no translation needed")
635
 
636
+ # Generate response in English
637
  answer = await llm_manager.vision_query(image, query_to_process)
638
  logger.info(f"Generated English answer: {answer}")
639
 
640
+ # Translate answer to target language if tgt_lang is not English
641
+ if tgt_lang != "eng_Latn":
642
  translated_answer = await perform_internal_translation(
643
  sentences=[answer],
644
  src_lang="eng_Latn",
 
648
  logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
649
  else:
650
  final_answer = answer
651
+ logger.info("Answer kept in English, no translation needed")
652
 
653
  return {"answer": final_answer}
654
  except Exception as e:
 
672
  logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
673
 
674
  try:
 
 
 
675
  if image:
676
  image_data = await image.read()
677
  if not image_data:
678
  raise HTTPException(status_code=400, detail="Uploaded image is empty")
679
  img = Image.open(io.BytesIO(image_data))
680
 
681
+ # Translate prompt to English if src_lang is not English
682
+ if src_lang != "eng_Latn":
683
  translated_prompt = await perform_internal_translation(
684
  sentences=[prompt],
685
  src_lang=src_lang,
 
689
  logger.info(f"Translated prompt to English: {prompt_to_process}")
690
  else:
691
  prompt_to_process = prompt
692
+ logger.info("Prompt already in English, no translation needed")
693
 
694
  decoded = await llm_manager.chat_v2(img, prompt_to_process)
695
+ logger.info(f"Generated English response: {decoded}")
696
 
697
+ # Translate response to target language if tgt_lang is not English
698
+ if tgt_lang != "eng_Latn":
699
  translated_response = await perform_internal_translation(
700
  sentences=[decoded],
701
  src_lang="eng_Latn",
 
705
  logger.info(f"Translated response to {tgt_lang}: {final_response}")
706
  else:
707
  final_response = decoded
708
+ logger.info("Response kept in English, no translation needed")
709
  else:
710
+ # Translate prompt to English if src_lang is not English
711
+ if src_lang != "eng_Latn":
712
  translated_prompt = await perform_internal_translation(
713
  sentences=[prompt],
714
  src_lang=src_lang,
 
718
  logger.info(f"Translated prompt to English: {prompt_to_process}")
719
  else:
720
  prompt_to_process = prompt
721
+ logger.info("Prompt already in English, no translation needed")
722
 
723
  decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
724
+ logger.info(f"Generated English response: {decoded}")
725
 
726
+ # Translate response to target language if tgt_lang is not English
727
+ if tgt_lang != "eng_Latn":
728
  translated_response = await perform_internal_translation(
729
  sentences=[decoded],
730
  src_lang="eng_Latn",
 
734
  logger.info(f"Translated response to {tgt_lang}: {final_response}")
735
  else:
736
  final_response = decoded
737
+ logger.info("Response kept in English, no translation needed")
738
 
739
  return ChatResponse(response=final_response)
740
  except Exception as e:
 
744
  class TranscriptionResponse(BaseModel):
745
  text: str
746
 
747
+
748
  class ASRModelManager:
749
  def __init__(self, device_type="cuda"):
750
  self.device_type = device_type
 
756
  "telugu": "te", "urdu": "ur"
757
  }
758
 
759
+
760
+ from fastapi import FastAPI, UploadFile
761
+ import torch
762
+ import torchaudio
763
+ from transformers import AutoModel
764
+ import argparse
765
+ import uvicorn
766
+ from pydantic import BaseModel
767
+ from pydub import AudioSegment
768
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Query
769
+ from fastapi.responses import RedirectResponse, JSONResponse
770
+ from typing import List
771
+
772
+ # Load the model
773
  model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
774
+
775
+ asr_manager = ASRModelManager() # Load Kannada, Hindi, Tamil, Telugu, Malayalam
776
+
777
+
778
+ #asr_manager = ASRModelManager(device_type="")
779
 
780
  @app.post("/transcribe/", response_model=TranscriptionResponse)
781
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
782
+ # Load the uploaded audio file
783
  wav, sr = torchaudio.load(file.file)
784
  wav = torch.mean(wav, dim=0, keepdim=True)
785
 
786
+ # Resample if necessary
787
+ target_sample_rate = 16000 # Expected sample rate
788
  if sr != target_sample_rate:
789
  resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
790
  wav = resampler(wav)
791
 
792
+ # Perform ASR with CTC decoding
793
+ #transcription_ctc = model(wav, "kn", "ctc")
794
+
795
+ # Perform ASR with RNNT decoding
796
  transcription_rnnt = model(wav, "kn", "rnnt")
797
+
798
  return JSONResponse(content={"text": transcription_rnnt})
799
 
800
+
801
+
802
  class BatchTranscriptionResponse(BaseModel):
803
  transcriptions: List[str]
804
 
805
+
806
+
807
  if __name__ == "__main__":
808
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
809
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")