sachin committed on
Commit 7e9e8cc · 1 Parent(s): fea8b58
Files changed (1)
  1. src/server/main.py +87 -86
src/server/main.py CHANGED
@@ -28,26 +28,19 @@ from tts_config import SPEED, ResponseFormat, config as tts_config
 import torchaudio
 
 # Device setup
-if torch.cuda.is_available():
-    device = "cuda:0"
-    logger.info("GPU will be used for inference")
-else:
-    device = "cpu"
-    logger.info("CPU will be used for inference")
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
+logger.info(f"Using device: {device} with dtype: {torch_dtype}")
 
 # Check CUDA availability and version
 cuda_available = torch.cuda.is_available()
 cuda_version = torch.version.cuda if cuda_available else None
-
-if torch.cuda.is_available():
+if cuda_available:
     device_idx = torch.cuda.current_device()
     capability = torch.cuda.get_device_capability(device_idx)
-    compute_capability_float = float(f"{capability[0]}.{capability[1]}")
-    print(f"CUDA version: {cuda_version}")
-    print(f"CUDA Compute Capability: {compute_capability_float}")
+    logger.info(f"CUDA version: {cuda_version}, Compute Capability: {capability[0]}.{capability[1]}")
 else:
-    print("CUDA is not available on this system.")
+    logger.info("CUDA is not available; falling back to CPU.")
 
 # Settings
 class Settings(BaseSettings):
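
Note: the rewritten one-liner selects torch.bfloat16 whenever any GPU is present, but native bfloat16 support begins with Ampere-class GPUs (compute capability 8.0). A stricter guard, sketched assuming a recent PyTorch that provides torch.cuda.is_bf16_supported():

    import torch

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Fall back to float32 on GPUs without native bfloat16 support.
    use_bf16 = device != "cpu" and torch.cuda.is_bf16_supported()
    torch_dtype = torch.bfloat16 if use_bf16 else torch.float32
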
@@ -94,14 +87,7 @@ class LLMManager:
         try:
             if self.device.type == "cuda":
                 torch.set_float32_matmul_precision('high')
-                logger.info("Enabled TF32 matrix multiplication for improved performance")
-
-                quantization_config = BitsAndBytesConfig(
-                    load_in_4bit=True,
-                    bnb_4bit_quant_type="nf4",
-                    bnb_4bit_compute_dtype=self.torch_dtype,
-                    bnb_4bit_use_double_quant=True
-                )
+                logger.info("Enabled TF32 matrix multiplication for improved GPU performance")
 
             self.model = Gemma3ForConditionalGeneration.from_pretrained(
                 self.model_name,
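
Note: this hunk removes the 4-bit BitsAndBytesConfig, so the model now loads unquantized in torch_dtype. For reference, quantization only takes effect when the config is passed to from_pretrained via quantization_config=; a hedged sketch of re-enabling it (the model id below is a placeholder; the server reads it from settings.llm_model_name):

    import torch
    from transformers import BitsAndBytesConfig, Gemma3ForConditionalGeneration

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = Gemma3ForConditionalGeneration.from_pretrained(
        "google/gemma-3-4b-it",          # placeholder model id
        quantization_config=bnb_config,  # without this argument the config is inert
    )
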
@@ -113,7 +99,7 @@ class LLMManager:
 
             self.processor = AutoProcessor.from_pretrained(self.model_name, use_fast=True)
             self.is_loaded = True
-            logger.info(f"LLM {self.model_name} loaded on {self.device} with 4-bit quantization and fast processor")
+            logger.info(f"LLM {self.model_name} loaded on {self.device}")
         except Exception as e:
             logger.error(f"Failed to load LLM: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
@@ -124,10 +110,10 @@ class LLMManager:
             del self.processor
             if self.device.type == "cuda":
                 torch.cuda.empty_cache()
-                logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
+                logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
             self.is_loaded = False
             self.token_cache.clear()
-            logger.info(f"LLM {self.model_name} unloaded from {self.device}")
+            logger.info(f"LLM {self.model_name} unloaded")
 
     async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
         if not self.is_loaded:
@@ -139,14 +125,8 @@ class LLMManager:
             return self.token_cache[cache_key]["response"]
 
         messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
-            },
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": prompt}]
-            }
+            {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]},
+            {"role": "user", "content": [{"type": "text", "text": prompt}]}
         ]
 
         try:
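
Note: the flattened messages_vlm is the same structure as the multi-line original and still follows the transformers chat-template schema (a role plus a list of typed content parts). A sketch of how such messages are typically consumed (standard transformers API; the actual call sits outside this hunk):

    inputs_vlm = self.processor.apply_chat_template(
        messages_vlm,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(self.model.device)
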
@@ -169,7 +149,7 @@ class LLMManager:
             input_len = inputs_vlm["input_ids"].shape[-1]
             adjusted_max_tokens = min(max_tokens, max(20, input_len * 2))
 
-            with torch.inference_mode():
+            with torch.no_grad():
                 generation = self.model.generate(
                     **inputs_vlm,
                     max_new_tokens=adjusted_max_tokens,
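
Note: this commit swaps torch.inference_mode() for torch.no_grad() at every generate() call site. Both disable gradient tracking; inference_mode is the stricter and usually faster variant whose outputs can never re-enter autograd, while no_grad outputs can. A tiny illustration:

    import torch

    x = torch.ones(3, requires_grad=True)
    with torch.inference_mode():
        y = x * 2  # y is an inference tensor; using it in autograd later raises an error
    with torch.no_grad():
        z = x * 2  # z has requires_grad=False but may safely re-enter autograd
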
@@ -189,14 +169,8 @@ class LLMManager:
             self.load()
 
         messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in one sentence maximum."}]
-            },
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image else [])
-            }
+            {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in one sentence maximum."}]},
+            {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image else [])}
         ]
 
         cache_key = f"vision_{query}_{'image' if image else 'no_image'}"
@@ -224,7 +198,7 @@ class LLMManager:
             input_len = inputs_vlm["input_ids"].shape[-1]
             adjusted_max_tokens = min(512, max(20, input_len * 2))
 
-            with torch.inference_mode():
+            with torch.no_grad():
                 generation = self.model.generate(
                     **inputs_vlm,
                     max_new_tokens=adjusted_max_tokens,
@@ -244,14 +218,8 @@ class LLMManager:
             self.load()
 
         messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
-            },
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image else [])
-            }
+            {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]},
+            {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image else [])}
         ]
 
         cache_key = f"chat_v2_{query}_{'image' if image else 'no_image'}"
@@ -279,7 +247,7 @@ class LLMManager:
             input_len = inputs_vlm["input_ids"].shape[-1]
             adjusted_max_tokens = min(512, max(20, input_len * 2))
 
-            with torch.inference_mode():
+            with torch.no_grad():
                 generation = self.model.generate(
                     **inputs_vlm,
                     max_new_tokens=adjusted_max_tokens,
@@ -297,19 +265,24 @@ class LLMManager:
 # TTS Manager
 class TTSManager:
     def __init__(self, device_type=device):
-        self.device_type = device_type
+        self.device_type = torch.device(device_type)
         self.model = None
         self.repo_id = "ai4bharat/IndicF5"
 
     def load(self):
         if not self.model:
-            logger.info("Loading TTS model IndicF5...")
-            self.model = AutoModel.from_pretrained(
-                self.repo_id,
-                trust_remote_code=True
-            )
-            self.model = self.model.to(self.device_type)
-            logger.info("TTS model IndicF5 loaded")
+            logger.info(f"Loading TTS model {self.repo_id} on {self.device_type}...")
+            self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
+            logger.info("TTS model loaded")
+
+    def unload(self):
+        if self.model:
+            del self.model
+            if self.device_type.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"TTS GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+        self.model = None
+        logger.info("TTS model unloaded")
 
     def synthesize(self, text, ref_audio_path, ref_text):
         if not self.model:
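
Note: the new unload() methods repeat the pattern from LLMManager.unload(): drop the Python reference, then ask the CUDA caching allocator to release its freed blocks. del only removes the name; the memory returns once the last reference dies, which an explicit gc.collect() makes deterministic. A sketch of the pattern:

    import gc
    import torch

    def release(model):
        del model        # drop the (ideally last) reference
        gc.collect()     # make the deallocation deterministic
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # hand cached blocks back to the driver
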
@@ -394,11 +367,11 @@ SUPPORTED_LANGUAGES = {
 
 # Translation Manager
 class TranslateManager:
-    def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
-        self.device_type = device_type
-        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)
+    def __init__(self, src_lang, tgt_lang, device_type=device):
+        self.device_type = torch.device(device_type)
+        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang)
 
-    def initialize_model(self, src_lang, tgt_lang, use_distilled):
+    def initialize_model(self, src_lang, tgt_lang, use_distilled=True):
         if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
             model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if use_distilled else "ai4bharat/indictrans2-en-indic-1B"
         elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
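
Note: the new __init__ drops the use_distilled parameter, yet the unchanged ModelManager (see the preload hunk below) still calls TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled) with four arguments, which would raise a TypeError at preload time. A signature that stays compatible with that call site (hypothetical fix, not part of this commit):

    def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
        self.device_type = torch.device(device_type)
        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)
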
@@ -417,6 +390,17 @@ class TranslateManager:
         ).to(self.device_type)
         return tokenizer, model
 
+    def unload(self):
+        if self.model:
+            del self.model
+            del self.tokenizer
+            if self.device_type.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"Translation GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+        self.model = None
+        self.tokenizer = None
+        logger.info("Translation model unloaded")
+
 class ModelManager:
     def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
         self.models: dict[str, TranslateManager] = {}
@@ -432,7 +416,7 @@ class ModelManager:
             ('kan_Knda', 'hin_Deva', 'indic_indic')
         ]
         for src_lang, tgt_lang, key in translation_pairs:
-            logger.info(f"Preloading translation model for {src_lang} -> {tgt_lang}...")
+            logger.info(f"Preloading translation model for {src_lang} -> {tgt_lang} on {self.device_type}...")
             self.models[key] = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
             logger.info(f"Translation model for {key} preloaded successfully")
 
@@ -452,21 +436,29 @@ class ModelManager:
 
 # ASR Manager
 class ASRModelManager:
-    def __init__(self, device_type="cuda"):
-        self.device_type = device_type
+    def __init__(self, device_type=device):
+        self.device_type = torch.device(device_type)
         self.model = None
         self.model_language = {"kannada": "kn"}
 
     def load(self):
         if not self.model:
-            logger.info("Loading ASR model...")
+            logger.info(f"Loading ASR model on {self.device_type}...")
             self.model = AutoModel.from_pretrained(
                 "ai4bharat/indic-conformer-600m-multilingual",
                 trust_remote_code=True
-            )
-            self.model = self.model.to(self.device_type)
+            ).to(self.device_type)
             logger.info("ASR model loaded")
 
+    def unload(self):
+        if self.model:
+            del self.model
+            if self.device_type.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"ASR GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+        self.model = None
+        logger.info("ASR model unloaded")
+
 # Global Managers
 llm_manager = LLMManager(settings.llm_model_name)
 model_manager = ModelManager()
@@ -552,15 +544,15 @@ translation_configs = []
 async def lifespan(app: FastAPI):
     def load_all_models():
         try:
-            logger.info("Loading LLM model...")
+            logger.info(f"Loading LLM model on {device}...")
             llm_manager.load()
             logger.info("LLM model loaded successfully")
 
-            logger.info("Loading TTS model...")
+            logger.info(f"Loading TTS model on {device}...")
             tts_manager.load()
             logger.info("TTS model loaded successfully")
 
-            logger.info("Loading ASR model...")
+            logger.info(f"Loading ASR model on {device}...")
             asr_manager.load()
             logger.info("ASR model loaded successfully")
 
@@ -574,7 +566,11 @@ async def lifespan(app: FastAPI):
     load_all_models()
     yield
    llm_manager.unload()
-    logger.info("Server shutdown complete")
+    tts_manager.unload()
+    asr_manager.unload()
+    for model in model_manager.models.values():
+        model.unload()
+    logger.info("Server shutdown complete; all models unloaded")
 
 # FastAPI App
 app = FastAPI(
@@ -585,7 +581,6 @@ app = FastAPI(
     lifespan=lifespan
 )
 
-# Add CORS Middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -594,7 +589,6 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Add Timing Middleware
 @app.middleware("http")
 async def add_request_timing(request: Request, call_next):
     start_time = time()
@@ -616,6 +610,10 @@ async def unload_all_models():
     try:
         logger.info("Starting to unload all models...")
         llm_manager.unload()
+        tts_manager.unload()
+        asr_manager.unload()
+        for model in model_manager.models.values():
+            model.unload()
         logger.info("All models unloaded successfully")
         return {"status": "success", "message": "All models unloaded"}
     except Exception as e:
@@ -627,6 +625,8 @@ async def load_all_models():
     try:
         logger.info("Starting to load all models...")
         llm_manager.load()
+        tts_manager.load()
+        asr_manager.load()
         logger.info("All models loaded successfully")
         return {"status": "success", "message": "All models loaded"}
     except Exception as e:
@@ -775,10 +775,9 @@ async def chat_v2(
         logger.error(f"Error processing request: {str(e)}")
         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
 
-# Include LLM Router
 app.include_router(llm_router)
 
-# Improved Endpoints
+# Improved Endpoints with GPU Optimization
 @app.post("/audio/speech", response_class=StreamingResponse)
 async def synthesize_kannada(request: KannadaSynthesizeRequest):
     if not tts_manager.model:
@@ -821,8 +820,11 @@ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(.
         if sr != target_sample_rate:
             logger.info(f"Resampling audio from {sr}Hz to {target_sample_rate}Hz")
             resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
-            wav = resampler(wav)
-        transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+            wav = resampler(wav).to(device)
+        else:
+            wav = wav.to(device)
+        with torch.no_grad():
+            transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
         logger.info(f"Transcription completed: {transcription_rnnt[:50]}...")
         return TranscriptionResponse(text=transcription_rnnt)
     except Exception as e:
@@ -837,8 +839,11 @@ async def transcribe_step(audio_data: bytes, language: str) -> str:
     target_sample_rate = 16000
     if sr != target_sample_rate:
         resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
-        wav = resampler(wav)
-    return asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+        wav = resampler(wav).to(device)
+    else:
+        wav = wav.to(device)
+    with torch.no_grad():
+        return asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
 
 async def synthesize_step(text: str) -> io.BytesIO:
     kannada_example = next((ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)"), None)
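
Note: both transcription paths now resample on CPU and then move the tensor to the device. torchaudio also offers a stateless functional variant that avoids constructing a transform object per request; an equivalent sketch (device is the module-level device string):

    import torch
    import torchaudio

    def prepare_wav(wav: torch.Tensor, sr: int, target_sr: int = 16000) -> torch.Tensor:
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=target_sr)
        return wav.to(device)
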
@@ -863,11 +868,9 @@ async def speech_to_speech(
 
     logger.info(f"Processing speech-to-speech for file: {file.filename} in language: {language}")
     try:
-        # Step 1: Transcribe
         transcription = await transcribe_step(audio_data, language)
         logger.info(f"Transcribed text: {transcription[:50]}...")
 
-        # Step 2: Process with LLM
         chat_request = ChatRequest(
             prompt=transcription,
             src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
@@ -876,7 +879,6 @@ async def speech_to_speech(
         processed_text = await chat(request, chat_request)
         logger.info(f"Processed text: {processed_text.response[:50]}...")
 
-        # Step 3: Synthesize
         audio_buffer = await synthesize_step(processed_text.response)
         logger.info("Speech-to-speech processing completed")
 
@@ -900,7 +902,8 @@ async def health_check():
         "translation_models": list(model_manager.models.keys()),
         "device": device,
         "cuda_available": cuda_available,
-        "cuda_version": cuda_version if cuda_available else "N/A"
+        "cuda_version": cuda_version if cuda_available else "N/A",
+        "gpu_memory_allocated": torch.cuda.memory_allocated() if cuda_available else 0
     }
     logger.info("Health check requested")
     return status
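
Note: torch.cuda.memory_allocated() reports bytes, so the new gpu_memory_allocated field can be hard to read at a glance. If a human-readable figure is preferred, a one-line conversion (sketch):

    status["gpu_memory_allocated_mb"] = round(torch.cuda.memory_allocated() / 2**20, 1)
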
@@ -967,7 +970,6 @@ LANGUAGE_TO_SCRIPT = {
     "kannada": "kan_Knda"
 }
 
-# Main Execution
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the FastAPI server.")
     parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
@@ -996,7 +998,6 @@ if __name__ == "__main__":
         llm_manager = LLMManager(settings.llm_model_name)
 
     if selected_config["components"]["ASR"]:
-        asr_model_name = selected_config["components"]["ASR"]["model"]
         asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
 
     if selected_config["components"]["Translation"]: