sachin committed
Commit b734e0b · 1 Parent(s): 636e178
Files changed (1)
  1. src/server/main.py +493 -468
src/server/main.py CHANGED
@@ -2,7 +2,7 @@ import argparse
2
  import io
3
  import os
4
  from time import time
5
- from typing import List, Dict
6
  import tempfile
7
  import uvicorn
8
  from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
@@ -22,23 +22,25 @@ from contextlib import asynccontextmanager
22
  import soundfile as sf
23
  import numpy as np
24
  import requests
25
- import logging
26
  from starlette.responses import StreamingResponse
27
- from logging_config import logger # Assumed external logging config
28
- from tts_config import SPEED, ResponseFormat, config as tts_config # Assumed external TTS config
29
  import torchaudio
30
- from tenacity import retry, stop_after_attempt, wait_exponential
31
- from torch.cuda.amp import autocast
32
 
33
  # Device setup
34
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
35
- torch_dtype = torch.float16 if device != "cpu" else torch.float32
36
- logger.info(f"Using device: {device} with dtype: {torch_dtype}")
37
 
38
  # Check CUDA availability and version
39
  cuda_available = torch.cuda.is_available()
40
  cuda_version = torch.version.cuda if cuda_available else None
41
- if cuda_available:
 
42
  device_idx = torch.cuda.current_device()
43
  capability = torch.cuda.get_device_capability(device_idx)
44
  compute_capability_float = float(f"{capability[0]}.{capability[1]}")
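Note for reviewers: this capability probe matters because the commit below moves from float16 to bfloat16, and bf16 is only hardware-accelerated from compute capability 8.0 (Ampere) upward. A minimal sketch of the usual gating rule, illustrative and not part of this commit:

import torch

def pick_dtype() -> torch.dtype:
    # bf16 is native on SM 8.0+ GPUs; older GPUs fall back to fp16, CPU to fp32.
    if not torch.cuda.is_available():
        return torch.float32
    major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return torch.bfloat16 if major >= 8 else torch.float16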
@@ -75,50 +77,33 @@ quantization_config = BitsAndBytesConfig(
75
  bnb_4bit_compute_dtype=torch.bfloat16
76
  )
77
 
78
- # Request queue for concurrency control
79
- request_queue = asyncio.Queue(maxsize=10)
80
-
81
- # Logging optimization
82
- logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
83
-
84
- # LLM Manager with batching
85
  class LLMManager:
86
- def __init__(self, model_name: str, device: str = device):
87
  self.model_name = model_name
88
  self.device = torch.device(device)
89
- self.torch_dtype = torch.float16 if self.device.type != "cpu" else torch.float32
90
  self.model = None
91
  self.processor = None
92
  self.is_loaded = False
93
- self.token_cache = {}
94
- self.load()
95
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
96
 
97
  def load(self):
98
  if not self.is_loaded:
99
  try:
100
- if self.device.type == "cuda":
101
- torch.set_float32_matmul_precision('high')
102
- logger.info("Enabled TF32 matrix multiplication for improved GPU performance")
103
  self.model = Gemma3ForConditionalGeneration.from_pretrained(
104
  self.model_name,
105
  device_map="auto",
106
  quantization_config=quantization_config,
107
  torch_dtype=self.torch_dtype
108
  )
109
- if self.model is None:
110
- raise ValueError(f"Failed to load model {self.model_name}: Model object is None")
111
  self.model.eval()
112
- self.processor = AutoProcessor.from_pretrained(self.model_name, use_fast=True)
113
- dummy_input = self.processor("test", return_tensors="pt").to(self.device)
114
- with torch.no_grad():
115
- self.model.generate(**dummy_input, max_new_tokens=10)
116
  self.is_loaded = True
117
- logger.info(f"LLM {self.model_name} loaded and warmed up on {self.device}")
118
  except Exception as e:
119
  logger.error(f"Failed to load LLM: {str(e)}")
120
- self.is_loaded = False
121
- raise # Re-raise to ensure failure is caught upstream
122
 
123
  def unload(self):
124
  if self.is_loaded:
@@ -126,72 +111,74 @@ class LLMManager:
126
  del self.processor
127
  if self.device.type == "cuda":
128
  torch.cuda.empty_cache()
129
- logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
130
  self.is_loaded = False
131
- self.token_cache.clear()
132
- logger.info(f"LLM {self.model_name} unloaded")
133
 
134
- async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
135
  if not self.is_loaded:
136
- logger.warning("LLM not loaded; attempting reload")
137
  self.load()
138
- if not self.is_loaded:
139
- raise HTTPException(status_code=503, detail="LLM model unavailable")
140
-
141
- cache_key = f"{prompt}:{max_tokens}:{temperature}"
142
- if cache_key in self.token_cache:
143
- logger.info("Using cached response")
144
- return self.token_cache[cache_key]
145
 
146
- future = asyncio.Future()
147
- await request_queue.put({"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature, "future": future})
148
- response = await future
149
- self.token_cache[cache_key] = response
150
- logger.info(f"Generated response: {response}")
151
- return response
152
-
153
- async def batch_generate(self, prompts: List[Dict]) -> List[str]:
154
- messages_batch = [
155
- [
156
- {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]},
157
- {"role": "user", "content": [{"type": "text", "text": prompt["prompt"]}]}
158
- ]
159
- for prompt in prompts
160
  ]
 
161
  try:
162
  inputs_vlm = self.processor.apply_chat_template(
163
- messages_batch,
164
  add_generation_prompt=True,
165
  tokenize=True,
166
  return_dict=True,
167
- return_tensors="pt",
168
- padding=True
169
  ).to(self.device, dtype=torch.bfloat16)
170
- with autocast(), torch.no_grad():
171
- outputs = self.model.generate(
172
- **inputs_vlm,
173
- max_new_tokens=max(prompt["max_tokens"] for prompt in prompts),
174
- do_sample=True,
175
- top_p=0.9,
176
- temperature=max(prompt["temperature"] for prompt in prompts)
177
- )
178
- responses = [
179
- self.processor.decode(output[input_len:], skip_special_tokens=True)
180
- for output, input_len in zip(outputs, inputs_vlm["input_ids"].shape[1])
181
- ]
182
- logger.info(f"Batch generated {len(responses)} responses")
183
- return responses
184
  except Exception as e:
185
- logger.error(f"Error in batch generation: {str(e)}")
186
- raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")
187
 
188
  async def vision_query(self, image: Image.Image, query: str) -> str:
189
  if not self.is_loaded:
190
  self.load()
 
191
  messages_vlm = [
192
- {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]},
193
- {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
194
  ]
195
  try:
196
  inputs_vlm = self.processor.apply_chat_template(
197
  messages_vlm,
@@ -203,10 +190,18 @@ class LLMManager:
203
  except Exception as e:
204
  logger.error(f"Error in apply_chat_template: {str(e)}")
205
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
 
206
  input_len = inputs_vlm["input_ids"].shape[-1]
 
207
  with torch.inference_mode():
208
- generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
209
  generation = generation[0][input_len:]
 
210
  decoded = self.processor.decode(generation, skip_special_tokens=True)
211
  logger.info(f"Vision query response: {decoded}")
212
  return decoded
@@ -214,10 +209,25 @@ class LLMManager:
214
  async def chat_v2(self, image: Image.Image, query: str) -> str:
215
  if not self.is_loaded:
216
  self.load()
 
217
  messages_vlm = [
218
- {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]},
219
- {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
220
  ]
221
  try:
222
  inputs_vlm = self.processor.apply_chat_template(
223
  messages_vlm,
@@ -229,10 +239,18 @@ class LLMManager:
229
  except Exception as e:
230
  logger.error(f"Error in apply_chat_template: {str(e)}")
231
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
 
232
  input_len = inputs_vlm["input_ids"].shape[-1]
 
233
  with torch.inference_mode():
234
- generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
235
  generation = generation[0][input_len:]
 
236
  decoded = self.processor.decode(generation, skip_special_tokens=True)
237
  logger.info(f"Chat_v2 response: {decoded}")
238
  return decoded
@@ -240,42 +258,101 @@ class LLMManager:
240
  # TTS Manager
241
  class TTSManager:
242
  def __init__(self, device_type=device):
243
- self.device_type = torch.device(device_type)
244
  self.model = None
245
  self.repo_id = "ai4bharat/IndicF5"
246
- self.load()
247
 
248
  def load(self):
249
  if not self.model:
250
- logger.info(f"Loading TTS model {self.repo_id} on {self.device_type}...")
251
- self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
252
- logger.info("TTS model loaded")
253
-
254
- def unload(self):
255
- if self.model:
256
- del self.model
257
- if self.device_type.type == "cuda":
258
- torch.cuda.empty_cache()
259
- logger.info(f"TTS GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
260
- self.model = None
261
- logger.info("TTS model unloaded")
262
 
263
  def synthesize(self, text, ref_audio_path, ref_text):
264
  if not self.model:
265
  raise ValueError("TTS model not loaded")
266
- with autocast():
267
- return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
268
 
269
  # Translation Manager
270
  class TranslateManager:
271
  def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
272
- self.device_type = torch.device(device_type)
273
  self.tokenizer = None
274
  self.model = None
275
  self.src_lang = src_lang
276
  self.tgt_lang = tgt_lang
277
  self.use_distilled = use_distilled
278
- self.load()
279
 
280
  def load(self):
281
  if not self.tokenizer or not self.model:
@@ -287,17 +364,21 @@ class TranslateManager:
287
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
288
  else:
289
  raise ValueError("Invalid language combination")
290
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
291
  self.model = AutoModelForSeq2SeqLM.from_pretrained(
292
  model_name,
293
  trust_remote_code=True,
294
  torch_dtype=torch.float16,
295
  attn_implementation="flash_attention_2"
296
- ).to(self.device_type)
 
297
  self.model = torch.compile(self.model, mode="reduce-overhead")
298
  logger.info(f"Translation model {model_name} loaded")
299
 
300
- # Model Manager
301
  class ModelManager:
302
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
303
  self.models = {}
@@ -308,14 +389,18 @@ class ModelManager:
308
  def load_model(self, src_lang, tgt_lang, key):
309
  logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
310
  translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
 
311
  self.models[key] = translate_manager
312
  logger.info(f"Loaded translation model for {key}")
313
 
314
  def get_model(self, src_lang, tgt_lang):
315
  key = self._get_model_key(src_lang, tgt_lang)
316
- if key not in self.models and self.is_lazy_loading:
317
- self.load_model(src_lang, tgt_lang, key)
318
- return self.models.get(key) or (self.load_model(src_lang, tgt_lang, key) or self.models[key])
319
 
320
  def _get_model_key(self, src_lang, tgt_lang):
321
  if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
@@ -328,30 +413,21 @@ class ModelManager:
328
 
329
  # ASR Manager
330
  class ASRModelManager:
331
- def __init__(self, device_type=device):
332
- self.device_type = torch.device(device_type)
333
  self.model = None
334
  self.model_language = {"kannada": "kn"}
335
- self.load()
336
 
337
  def load(self):
338
  if not self.model:
339
- logger.info(f"Loading ASR model on {self.device_type}...")
340
  self.model = AutoModel.from_pretrained(
341
  "ai4bharat/indic-conformer-600m-multilingual",
342
  trust_remote_code=True
343
- ).to(self.device_type)
 
344
  logger.info("ASR model loaded")
345
 
346
- def unload(self):
347
- if self.model:
348
- del self.model
349
- if self.device_type.type == "cuda":
350
- torch.cuda.empty_cache()
351
- logger.info(f"ASR GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
352
- self.model = None
353
- logger.info("ASR model unloaded")
354
-
355
  # Global Managers
356
  llm_manager = LLMManager(settings.llm_model_name)
357
  model_manager = ModelManager()
@@ -359,31 +435,7 @@ asr_manager = ASRModelManager()
359
  tts_manager = TTSManager()
360
  ip = IndicProcessor(inference=True)
361
 
362
- # TTS Constants
363
- EXAMPLES = [
364
- {
365
- "audio_name": "KAN_F (Happy)",
366
- "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
367
- "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ।",
368
- "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
369
- },
370
- ]
371
-
372
  # Pydantic Models
373
- class SynthesizeRequest(BaseModel):
374
- text: str
375
- ref_audio_name: str
376
- ref_text: str = None
377
-
378
- class KannadaSynthesizeRequest(BaseModel):
379
- text: str
380
-
381
- @field_validator("text")
382
- def text_must_be_valid(cls, v):
383
- if len(v) > 500:
384
- raise ValueError("Text cannot exceed 500 characters")
385
- return v.strip()
386
-
387
  class ChatRequest(BaseModel):
388
  prompt: str
389
  src_lang: str = "kan_Knda"
@@ -401,6 +453,7 @@ class ChatRequest(BaseModel):
401
  raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
402
  return v
403
 
 
404
  class ChatResponse(BaseModel):
405
  response: str
406
 
@@ -415,149 +468,71 @@ class TranscriptionResponse(BaseModel):
415
  class TranslationResponse(BaseModel):
416
  translations: List[str]
417
 
418
- # TTS Functions
419
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
420
- def load_audio_from_url(url: str):
421
- response = requests.get(url)
422
- if response.status_code == 200:
423
- audio_data, sample_rate = sf.read(io.BytesIO(response.content))
424
- return sample_rate, audio_data
425
- raise HTTPException(status_code=500, detail="Failed to load reference audio from URL after retries")
426
-
427
- async def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str) -> io.BytesIO:
428
- async with request_queue:
429
- ref_audio_url = next((ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
430
- if not ref_audio_url:
431
- raise HTTPException(status_code=400, detail="Invalid reference audio name.")
432
- if not text.strip() or not ref_text.strip():
433
- raise HTTPException(status_code=400, detail="Text or reference text cannot be empty.")
434
-
435
- logger.info(f"Synthesizing speech for text: {text[:50]}... with ref_audio: {ref_audio_name}")
436
- loop = asyncio.get_running_loop()
437
- sample_rate, audio_data = await loop.run_in_executor(None, load_audio_from_url, ref_audio_url)
438
-
439
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio:
440
- await loop.run_in_executor(None, sf.write, temp_audio.name, audio_data, sample_rate, "WAV")
441
- temp_audio.flush()
442
- audio = tts_manager.synthesize(text, temp_audio.name, ref_text)
443
-
444
- buffer = io.BytesIO()
445
- await loop.run_in_executor(None, sf.write, buffer, audio.astype(np.float32) / 32768.0 if audio.dtype == np.int16 else audio, 24000, "WAV")
446
- buffer.seek(0)
447
- logger.info("Speech synthesis completed")
448
- return buffer
449
-
450
- # Supported Languages
451
- SUPPORTED_LANGUAGES = {
452
- "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
453
- "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
454
- "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
455
- "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
456
- "kan_Knda", "ory_Orya",
457
- "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
458
- "por_Latn", "rus_Cyrl", "pol_Latn"
459
- }
460
-
461
  # Dependency
462
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
463
  return model_manager.get_model(src_lang, tgt_lang)
464
 
465
- # Translation Function
466
- async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
467
- try:
468
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
469
- except ValueError as e:
470
- logger.info(f"Model not preloaded: {str(e)}, loading now...")
471
- key = model_manager._get_model_key(src_lang, tgt_lang)
472
- model_manager.load_model(src_lang, tgt_lang, key)
473
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
474
-
475
- batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
476
- inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
477
- with torch.no_grad(), autocast():
478
- generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
479
- with translate_manager.tokenizer.as_target_tokenizer():
480
- generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
481
- return ip.postprocess_batch(generated_tokens, lang=tgt_lang)
482
-
483
  # Lifespan Event Handler
484
  translation_configs = []
485
 
486
  @asynccontextmanager
487
  async def lifespan(app: FastAPI):
488
  def load_all_models():
489
- logger.info("Loading LLM model...")
490
- llm_manager.load()
491
- logger.info("Loading TTS model...")
492
- tts_manager.load()
493
- logger.info("Loading ASR model...")
494
- asr_manager.load()
495
- translation_tasks = [
496
- ('eng_Latn', 'kan_Knda', 'eng_indic'),
497
- ('kan_Knda', 'eng_Latn', 'indic_eng'),
498
- ('kan_Knda', 'hin_Deva', 'indic_indic'),
499
- ]
500
- for config in translation_configs:
501
- src_lang = config["src_lang"]
502
- tgt_lang = config["tgt_lang"]
503
- key = model_manager._get_model_key(src_lang, tgt_lang)
504
- translation_tasks.append((src_lang, tgt_lang, key))
505
- for src_lang, tgt_lang, key in translation_tasks:
506
- logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
507
- model_manager.load_model(src_lang, tgt_lang, key)
508
- logger.info("All models loaded successfully")
509
 
510
- logger.info("Starting server with preloaded models...")
511
  load_all_models()
512
- batch_task = asyncio.create_task(batch_worker())
513
  yield
514
- batch_task.cancel()
515
  llm_manager.unload()
516
- tts_manager.unload()
517
- asr_manager.unload()
518
- for model in model_manager.models.values():
519
- model.unload()
520
- logger.info("Server shutdown complete; all models unloaded")
521
-
522
- # Batch Worker
523
- async def batch_worker():
524
- while True:
525
- batch = []
526
- last_request_time = time()
527
- try:
528
- while len(batch) < 4:
529
- try:
530
- request = await asyncio.wait_for(request_queue.get(), timeout=1.0)
531
- batch.append(request)
532
- current_time = time()
533
- if current_time - last_request_time > 1.0 and batch:
534
- break
535
- last_request_time = current_time
536
- except asyncio.TimeoutError:
537
- if batch:
538
- break
539
- continue
540
- if batch:
541
- start_time = time()
542
- responses = await llm_manager.batch_generate(batch)
543
- duration = time() - start_time
544
- logger.info(f"Batch of {len(batch)} requests processed in {duration:.3f} seconds")
545
- for request, response in zip(batch, responses):
546
- request["future"].set_result(response)
547
- except Exception as e:
548
- logger.error(f"Batch worker error: {str(e)}")
549
- for request in batch:
550
- request["future"].set_exception(e)
551
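Reviewer note: this hunk removes the queue-based batch worker entirely. For reference, the collect-then-flush pattern it implemented looks roughly like this; a sketch of the pattern, not the deleted code verbatim:

import asyncio

async def collect_batch(queue: asyncio.Queue, max_size: int = 4, max_wait: float = 1.0) -> list:
    # Block for the first item, then keep collecting until the batch is full
    # or max_wait seconds have elapsed since the batch started.
    batch = [await queue.get()]
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait
    while len(batch) < max_size:
        remaining = deadline - loop.time()
        if remaining <= 0:
            break
        try:
            batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
        except asyncio.TimeoutError:
            break
    return batch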
 
552
  # FastAPI App
553
  app = FastAPI(
554
- title="Optimized Dhwani API",
555
- description="AI Chat API supporting Indian languages with performance enhancements",
556
  version="1.0.0",
557
  redirect_slashes=False,
558
  lifespan=lifespan
559
  )
560
 
 
561
  app.add_middleware(
562
  CORSMiddleware,
563
  allow_origins=["*"],
@@ -566,11 +541,13 @@ app.add_middleware(
566
  allow_headers=["*"],
567
  )
568
 
 
569
  @app.middleware("http")
570
  async def add_request_timing(request: Request, call_next):
571
  start_time = time()
572
  response = await call_next(request)
573
- duration = time() - start_time
 
574
  logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
575
  response.headers["X-Response-Time"] = f"{duration:.3f}"
576
  return response
@@ -578,7 +555,7 @@ async def add_request_timing(request: Request, call_next):
578
  limiter = Limiter(key_func=get_remote_address)
579
  app.state.limiter = limiter
580
 
581
- # Endpoints
582
  @app.post("/v1/audio/speech", response_class=StreamingResponse)
583
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
584
  if not tts_manager.model:
@@ -586,78 +563,77 @@ async def synthesize_kannada(request: KannadaSynthesizeRequest):
586
  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
587
  if not request.text.strip():
588
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
589
- audio_buffer = await synthesize_speech(tts_manager, request.text, "KAN_F (Happy)", kannada_example["ref_text"])
590
  return StreamingResponse(
591
  audio_buffer,
592
  media_type="audio/wav",
593
  headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
594
  )
595
 
596
- @app.post("/v1/translate", response_model=TranslationResponse)
597
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
598
- if not request.sentences:
599
  raise HTTPException(status_code=400, detail="Input sentences are required")
600
- batch = ip.preprocess_batch(request.sentences, src_lang=request.src_lang, tgt_lang=request.tgt_lang)
601
- inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
602
- with torch.no_grad(), autocast():
603
- generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
604
  with translate_manager.tokenizer.as_target_tokenizer():
605
- generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
606
- translations = ip.postprocess_batch(generated_tokens, lang=request.tgt_lang)
607
  return TranslationResponse(translations=translations)
608
 
609
  @app.get("/v1/health")
610
  async def health_check():
611
- memory_usage = torch.cuda.memory_allocated() / (24 * 1024**3) if cuda_available else 0
612
- if memory_usage > 0.9:
613
- logger.warning("GPU memory usage exceeds 90%; consider unloading models")
614
- llm_status = "unhealthy"
615
- llm_latency = None
616
- if llm_manager.is_loaded:
617
- start = time()
618
- try:
619
- llm_test = await llm_manager.generate("What is the capital of Karnataka?", max_tokens=10)
620
- llm_latency = time() - start
621
- llm_status = "healthy" if llm_test else "unhealthy"
622
- except Exception as e:
623
- logger.error(f"LLM health check failed: {str(e)}")
624
- tts_status = "unhealthy"
625
- tts_latency = None
626
- if tts_manager.model:
627
- start = time()
628
- try:
629
- audio_buffer = await synthesize_speech(tts_manager, "Test", "KAN_F (Happy)", EXAMPLES[0]["ref_text"])
630
- tts_latency = time() - start
631
- tts_status = "healthy" if audio_buffer else "unhealthy"
632
- except Exception as e:
633
- logger.error(f"TTS health check failed: {str(e)}")
634
- asr_status = "unhealthy"
635
- asr_latency = None
636
- if asr_manager.model:
637
- start = time()
638
- try:
639
- dummy_audio = np.zeros(16000, dtype=np.float32)
640
- wav = torch.tensor(dummy_audio).unsqueeze(0).to(device)
641
- with autocast(), torch.no_grad():
642
- asr_test = asr_manager.model(wav, asr_manager.model_language["kannada"], "rnnt")
643
- asr_latency = time() - start
644
- asr_status = "healthy" if asr_test else "unhealthy"
645
- except Exception as e:
646
- logger.error(f"ASR health check failed: {str(e)}")
647
- status = {
648
- "status": "healthy" if llm_status == "healthy" and tts_status == "healthy" and asr_status == "healthy" else "degraded",
649
- "model": settings.llm_model_name,
650
- "llm_status": llm_status,
651
- "llm_latency": f"{llm_latency:.3f}s" if llm_latency else "N/A",
652
- "tts_status": tts_status,
653
- "tts_latency": f"{tts_latency:.3f}s" if tts_latency else "N/A",
654
- "asr_status": asr_status,
655
- "asr_latency": f"{asr_latency:.3f}s" if asr_latency else "N/A",
656
- "translation_models": list(model_manager.models.keys()),
657
- "gpu_memory_usage": f"{memory_usage:.2%}"
658
- }
659
- logger.info("Health check completed")
660
- return status
661
 
662
  @app.get("/")
663
  async def home():
@@ -668,10 +644,6 @@ async def unload_all_models():
668
  try:
669
  logger.info("Starting to unload all models...")
670
  llm_manager.unload()
671
- tts_manager.unload()
672
- asr_manager.unload()
673
- for model in model_manager.models.values():
674
- model.unload()
675
  logger.info("All models unloaded successfully")
676
  return {"status": "success", "message": "All models unloaded"}
677
  except Exception as e:
@@ -683,15 +655,6 @@ async def load_all_models():
683
  try:
684
  logger.info("Starting to load all models...")
685
  llm_manager.load()
686
- tts_manager.load()
687
- asr_manager.load()
688
- for src_lang, tgt_lang, key in [
689
- ('eng_Latn', 'kan_Knda', 'eng_indic'),
690
- ('kan_Knda', 'eng_Latn', 'indic_eng'),
691
- ('kan_Knda', 'hin_Deva', 'indic_indic'),
692
- ]:
693
- if key not in model_manager.models:
694
- model_manager.load_model(src_lang, tgt_lang, key)
695
  logger.info("All models loaded successfully")
696
  return {"status": "success", "message": "All models loaded"}
697
  except Exception as e:
@@ -702,7 +665,11 @@ async def load_all_models():
702
  async def translate_endpoint(request: TranslationRequest):
703
  logger.info(f"Received translation request: {request.dict()}")
704
  try:
705
706
  logger.info(f"Translation successful: {translations}")
707
  return TranslationResponse(translations=translations)
708
  except Exception as e:
@@ -712,32 +679,44 @@ async def translate_endpoint(request: TranslationRequest):
712
  @app.post("/v1/chat", response_model=ChatResponse)
713
  @limiter.limit(settings.chat_rate_limit)
714
  async def chat(request: Request, chat_request: ChatRequest):
715
- async with request_queue:
716
- if not chat_request.prompt:
717
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
718
- logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
719
- EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
720
- try:
721
- if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
722
- translated_prompt = await perform_internal_translation([chat_request.prompt], chat_request.src_lang, "eng_Latn")
723
- prompt_to_process = translated_prompt[0]
724
- logger.info(f"Translated prompt to English: {prompt_to_process}")
725
- else:
726
- prompt_to_process = chat_request.prompt
727
- logger.info("Prompt in English or European language, no translation needed")
728
- response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
729
- logger.info(f"Generated English response: {response}")
730
- if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
731
- translated_response = await perform_internal_translation([response], "eng_Latn", chat_request.tgt_lang)
732
- final_response = translated_response[0]
733
- logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
734
- else:
735
- final_response = response
736
- logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
737
- return ChatResponse(response=final_response)
738
- except Exception as e:
739
- logger.error(f"Error processing request: {str(e)}")
740
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
741
 
742
  @app.post("/v1/visual_query/")
743
  async def visual_query(
@@ -746,31 +725,42 @@ async def visual_query(
746
  src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
747
  tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
748
  ):
749
- async with request_queue:
750
- try:
751
- image = Image.open(file.file)
752
- if image.size == (0, 0):
753
- raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
754
- if src_lang != "eng_Latn":
755
- translated_query = await perform_internal_translation([query], src_lang, "eng_Latn")
756
- query_to_process = translated_query[0]
757
- logger.info(f"Translated query to English: {query_to_process}")
758
- else:
759
- query_to_process = query
760
- logger.info("Query already in English, no translation needed")
761
- answer = await llm_manager.vision_query(image, query_to_process)
762
- logger.info(f"Generated English answer: {answer}")
763
- if tgt_lang != "eng_Latn":
764
- translated_answer = await perform_internal_translation([answer], "eng_Latn", tgt_lang)
765
- final_answer = translated_answer[0]
766
- logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
767
- else:
768
- final_answer = answer
769
- logger.info("Answer kept in English, no translation needed")
770
- return {"answer": final_answer}
771
- except Exception as e:
772
- logger.error(f"Error processing request: {str(e)}")
773
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
774
 
775
  @app.post("/v1/chat_v2", response_model=ChatResponse)
776
  @limiter.limit(settings.chat_rate_limit)
@@ -781,70 +771,95 @@ async def chat_v2(
781
  src_lang: str = Form("kan_Knda"),
782
  tgt_lang: str = Form("kan_Knda"),
783
  ):
784
- async with request_queue:
785
- if not prompt:
786
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
787
- if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
788
- raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
789
- logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
790
- try:
791
- if image:
792
- image_data = await image.read()
793
- if not image_data:
794
- raise HTTPException(status_code=400, detail="Uploaded image is empty")
795
- img = Image.open(io.BytesIO(image_data))
796
- if src_lang != "eng_Latn":
797
- translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
798
- prompt_to_process = translated_prompt[0]
799
- logger.info(f"Translated prompt to English: {prompt_to_process}")
800
- else:
801
- prompt_to_process = prompt
802
- decoded = await llm_manager.chat_v2(img, prompt_to_process)
803
- logger.info(f"Generated English response: {decoded}")
804
- if tgt_lang != "eng_Latn":
805
- translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
806
- final_response = translated_response[0]
807
- logger.info(f"Translated response to {tgt_lang}: {final_response}")
808
- else:
809
- final_response = decoded
810
  else:
811
- if src_lang != "eng_Latn":
812
- translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
813
- prompt_to_process = translated_prompt[0]
814
- logger.info(f"Translated prompt to English: {prompt_to_process}")
815
- else:
816
- prompt_to_process = prompt
817
- decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
818
- logger.info(f"Generated English response: {decoded}")
819
- if tgt_lang != "eng_Latn":
820
- translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
821
- final_response = translated_response[0]
822
- logger.info(f"Translated response to {tgt_lang}: {final_response}")
823
- else:
824
- final_response = decoded
825
- return ChatResponse(response=final_response)
826
- except Exception as e:
827
- logger.error(f"Error processing request: {str(e)}")
828
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
829
 
830
  @app.post("/v1/transcribe/", response_model=TranscriptionResponse)
831
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
832
- async with request_queue:
833
- if not asr_manager.model:
834
- raise HTTPException(status_code=503, detail="ASR model not loaded")
835
- try:
836
- wav, sr = torchaudio.load(file.file, backend="cuda" if cuda_available else "cpu")
837
- wav = torch.mean(wav, dim=0, keepdim=True).to(device)
838
- target_sample_rate = 16000
839
- if sr != target_sample_rate:
840
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
841
- wav = resampler(wav)
842
- with autocast(), torch.no_grad():
843
- transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
844
- return TranscriptionResponse(text=transcription_rnnt)
845
- except Exception as e:
846
- logger.error(f"Error in transcription: {str(e)}")
847
- raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
848
 
849
  @app.post("/v1/speech_to_speech")
850
  async def speech_to_speech(
@@ -852,20 +867,28 @@ async def speech_to_speech(
852
  file: UploadFile = File(...),
853
  language: str = Query(..., enum=list(asr_manager.model_language.keys())),
854
  ) -> StreamingResponse:
855
- async with request_queue:
856
- if not tts_manager.model:
857
- raise HTTPException(status_code=503, detail="TTS model not loaded")
858
- transcription = await transcribe_audio(file, language)
859
- logger.info(f"Transcribed text: {transcription.text}")
860
- chat_request = ChatRequest(prompt=transcription.text, src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"), tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"))
861
- processed_text = await chat(request, chat_request)
862
- logger.info(f"Processed text: {processed_text.response}")
863
- voice_request = KannadaSynthesizeRequest(text=processed_text.response)
864
- audio_response = await synthesize_kannada(voice_request)
865
- return audio_response
866
-
867
- LANGUAGE_TO_SCRIPT = {"kannada": "kan_Knda"}
868
 
869
  if __name__ == "__main__":
870
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
871
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
@@ -892,13 +915,15 @@ if __name__ == "__main__":
892
  settings.speech_rate_limit = global_settings["speech_rate_limit"]
893
 
894
  llm_manager = LLMManager(settings.llm_model_name)
 
895
  if selected_config["components"]["ASR"]:
 
896
  asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
 
897
  if selected_config["components"]["Translation"]:
898
  translation_configs.extend(selected_config["components"]["Translation"])
899
 
900
  host = args.host if args.host != settings.host else settings.host
901
  port = args.port if args.port != settings.port else settings.port
902
 
903
- # Run Uvicorn with import string to support workers
904
- uvicorn.run("main:app", host=host, port=port, workers=2)
 
2
  import io
3
  import os
4
  from time import time
5
+ from typing import List
6
  import tempfile
7
  import uvicorn
8
  from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
 
22
  import soundfile as sf
23
  import numpy as np
24
  import requests
 
25
  from starlette.responses import StreamingResponse
26
+ from logging_config import logger
27
+ from tts_config import SPEED, ResponseFormat, config as tts_config
28
  import torchaudio
 
 
29
 
30
  # Device setup
31
+ if torch.cuda.is_available():
32
+ device = "cuda:0"
33
+ logger.info("GPU will be used for inference")
34
+ else:
35
+ device = "cpu"
36
+ logger.info("CPU will be used for inference")
37
+ torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
38
 
39
  # Check CUDA availability and version
40
  cuda_available = torch.cuda.is_available()
41
  cuda_version = torch.version.cuda if cuda_available else None
42
+
43
+ if torch.cuda.is_available():
44
  device_idx = torch.cuda.current_device()
45
  capability = torch.cuda.get_device_capability(device_idx)
46
  compute_capability_float = float(f"{capability[0]}.{capability[1]}")
 
77
  bnb_4bit_compute_dtype=torch.bfloat16
78
  )
79
 
80
+ # LLM Manager
81
  class LLMManager:
82
+ def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
83
  self.model_name = model_name
84
  self.device = torch.device(device)
85
+ self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
86
  self.model = None
87
  self.processor = None
88
  self.is_loaded = False
 
 
89
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
90
 
91
  def load(self):
92
  if not self.is_loaded:
93
  try:
94
  self.model = Gemma3ForConditionalGeneration.from_pretrained(
95
  self.model_name,
96
  device_map="auto",
97
  quantization_config=quantization_config,
98
  torch_dtype=self.torch_dtype
99
  )
 
 
100
  self.model.eval()
101
+ self.processor = AutoProcessor.from_pretrained(self.model_name)
102
  self.is_loaded = True
103
+ logger.info(f"LLM {self.model_name} loaded on {self.device}")
104
  except Exception as e:
105
  logger.error(f"Failed to load LLM: {str(e)}")
106
+ raise
 
107
 
108
  def unload(self):
109
  if self.is_loaded:
 
111
  del self.processor
112
  if self.device.type == "cuda":
113
  torch.cuda.empty_cache()
114
+ logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
115
  self.is_loaded = False
116
+ logger.info(f"LLM {self.model_name} unloaded from {self.device}")
 
117
 
118
+ async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
119
  if not self.is_loaded:
 
120
  self.load()
121
 
122
+ messages_vlm = [
123
+ {
124
+ "role": "system",
125
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
126
+ },
127
+ {
128
+ "role": "user",
129
+ "content": [{"type": "text", "text": prompt}]
130
+ }
131
  ]
132
+
133
  try:
134
  inputs_vlm = self.processor.apply_chat_template(
135
+ messages_vlm,
136
  add_generation_prompt=True,
137
  tokenize=True,
138
  return_dict=True,
139
+ return_tensors="pt"
 
140
  ).to(self.device, dtype=torch.bfloat16)
141
  except Exception as e:
142
+ logger.error(f"Error in tokenization: {str(e)}")
143
+ raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
144
+
145
+ input_len = inputs_vlm["input_ids"].shape[-1]
146
+
147
+ with torch.inference_mode():
148
+ generation = self.model.generate(
149
+ **inputs_vlm,
150
+ max_new_tokens=max_tokens,
151
+ do_sample=True,
152
+ temperature=temperature
153
+ )
154
+ generation = generation[0][input_len:]
155
+
156
+ response = self.processor.decode(generation, skip_special_tokens=True)
157
+ logger.info(f"Generated response: {response}")
158
+ return response
159
 
160
  async def vision_query(self, image: Image.Image, query: str) -> str:
161
  if not self.is_loaded:
162
  self.load()
163
+
164
  messages_vlm = [
165
+ {
166
+ "role": "system",
167
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
168
+ },
169
+ {
170
+ "role": "user",
171
+ "content": []
172
+ }
173
  ]
174
+
175
+ messages_vlm[1]["content"].append({"type": "text", "text": query})
176
+ if image and image.size[0] > 0 and image.size[1] > 0:
177
+ messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
178
+ logger.info(f"Received valid image for processing")
179
+ else:
180
+ logger.info("No valid image provided, processing text only")
181
+
182
  try:
183
  inputs_vlm = self.processor.apply_chat_template(
184
  messages_vlm,
 
190
  except Exception as e:
191
  logger.error(f"Error in apply_chat_template: {str(e)}")
192
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
193
+
194
  input_len = inputs_vlm["input_ids"].shape[-1]
195
+
196
  with torch.inference_mode():
197
+ generation = self.model.generate(
198
+ **inputs_vlm,
199
+ max_new_tokens=512,
200
+ do_sample=True,
201
+ temperature=0.7
202
+ )
203
  generation = generation[0][input_len:]
204
+
205
  decoded = self.processor.decode(generation, skip_special_tokens=True)
206
  logger.info(f"Vision query response: {decoded}")
207
  return decoded
 
209
  async def chat_v2(self, image: Image.Image, query: str) -> str:
210
  if not self.is_loaded:
211
  self.load()
212
+
213
  messages_vlm = [
214
+ {
215
+ "role": "system",
216
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
217
+ },
218
+ {
219
+ "role": "user",
220
+ "content": []
221
+ }
222
  ]
223
+
224
+ messages_vlm[1]["content"].append({"type": "text", "text": query})
225
+ if image and image.size[0] > 0 and image.size[1] > 0:
226
+ messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
227
+ logger.info(f"Received valid image for processing")
228
+ else:
229
+ logger.info("No valid image provided, processing text only")
230
+
231
  try:
232
  inputs_vlm = self.processor.apply_chat_template(
233
  messages_vlm,
 
239
  except Exception as e:
240
  logger.error(f"Error in apply_chat_template: {str(e)}")
241
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
242
+
243
  input_len = inputs_vlm["input_ids"].shape[-1]
244
+
245
  with torch.inference_mode():
246
+ generation = self.model.generate(
247
+ **inputs_vlm,
248
+ max_new_tokens=512,
249
+ do_sample=True,
250
+ temperature=0.7
251
+ )
252
  generation = generation[0][input_len:]
253
+
254
  decoded = self.processor.decode(generation, skip_special_tokens=True)
255
  logger.info(f"Chat_v2 response: {decoded}")
256
  return decoded
 
258
  # TTS Manager
259
  class TTSManager:
260
  def __init__(self, device_type=device):
261
+ self.device_type = device_type
262
  self.model = None
263
  self.repo_id = "ai4bharat/IndicF5"
 
264
 
265
  def load(self):
266
  if not self.model:
267
+ logger.info("Loading TTS model IndicF5...")
268
+ self.model = AutoModel.from_pretrained(
269
+ self.repo_id,
270
+ trust_remote_code=True
271
+ )
272
+ self.model = self.model.to(self.device_type)
273
+ logger.info("TTS model IndicF5 loaded")
274
 
275
  def synthesize(self, text, ref_audio_path, ref_text):
276
  if not self.model:
277
  raise ValueError("TTS model not loaded")
278
+ return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
279
+
280
+ # TTS Constants
281
+ EXAMPLES = [
282
+ {
283
+ "audio_name": "KAN_F (Happy)",
284
+ "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
285
+ "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗ���ಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
286
+ "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
287
+ },
288
+ ]
289
+
290
+ # Pydantic models for TTS
291
+ class SynthesizeRequest(BaseModel):
292
+ text: str
293
+ ref_audio_name: str
294
+ ref_text: str = None
295
+
296
+ class KannadaSynthesizeRequest(BaseModel):
297
+ text: str
298
+
299
+ # TTS Functions
300
+ def load_audio_from_url(url: str):
301
+ response = requests.get(url)
302
+ if response.status_code == 200:
303
+ audio_data, sample_rate = sf.read(io.BytesIO(response.content))
304
+ return sample_rate, audio_data
305
+ raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
306
+
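Note: the commit drops the tenacity @retry decorator that previously wrapped this loader. If retries are still wanted, a dependency-free sketch with exponential backoff (the function name here is illustrative):

import time
import requests

def get_with_retries(url: str, attempts: int = 3) -> requests.Response:
    for i in range(attempts):
        try:
            resp = requests.get(url, timeout=10)
            resp.raise_for_status()
            return resp
        except requests.RequestException:
            if i == attempts - 1:
                raise
            time.sleep(min(2 ** i, 10))  # 1s, 2s, 4s, ... capped at 10s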
307
+ def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
308
+ ref_audio_url = None
309
+ for example in EXAMPLES:
310
+ if example["audio_name"] == ref_audio_name:
311
+ ref_audio_url = example["audio_url"]
312
+ if not ref_text:
313
+ ref_text = example["ref_text"]
314
+ break
315
+
316
+ if not ref_audio_url:
317
+ raise HTTPException(status_code=400, detail="Invalid reference audio name.")
318
+ if not text.strip():
319
+ raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
320
+ if not ref_text or not ref_text.strip():
321
+ raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
322
+
323
+ sample_rate, audio_data = load_audio_from_url(ref_audio_url)
324
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
325
+ sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
326
+ temp_audio.flush()
327
+ audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
328
+
329
+ if audio.dtype == np.int16:
330
+ audio = audio.astype(np.float32) / 32768.0
331
+ buffer = io.BytesIO()
332
+ sf.write(buffer, audio, 24000, format='WAV')
333
+ buffer.seek(0)
334
+ return buffer
335
+
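Note: with delete=False, each synthesis call leaves a reference .wav behind in the temp directory. A cleanup sketch under that assumption (a hypothetical helper, not in the commit):

import os
import tempfile
import soundfile as sf

def synthesize_with_temp_ref(tts_manager, text, audio_data, sample_rate, ref_text):
    # Write the reference clip, synthesize, then always remove the temp file.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        sf.write(tmp.name, audio_data, samplerate=sample_rate, format="WAV")
        tmp.flush()
        return tts_manager.synthesize(text, ref_audio_path=tmp.name, ref_text=ref_text)
    finally:
        tmp.close()
        os.unlink(tmp.name)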
336
+ # Supported languages
337
+ SUPPORTED_LANGUAGES = {
338
+ "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
339
+ "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
340
+ "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
341
+ "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
342
+ "kan_Knda", "ory_Orya",
343
+ "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
344
+ "por_Latn", "rus_Cyrl", "pol_Latn"
345
+ }
346
 
347
  # Translation Manager
348
  class TranslateManager:
349
  def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
350
+ self.device_type = device_type
351
  self.tokenizer = None
352
  self.model = None
353
  self.src_lang = src_lang
354
  self.tgt_lang = tgt_lang
355
  self.use_distilled = use_distilled
 
356
 
357
  def load(self):
358
  if not self.tokenizer or not self.model:
 
364
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
365
  else:
366
  raise ValueError("Invalid language combination")
367
+
368
+ self.tokenizer = AutoTokenizer.from_pretrained(
369
+ model_name,
370
+ trust_remote_code=True
371
+ )
372
  self.model = AutoModelForSeq2SeqLM.from_pretrained(
373
  model_name,
374
  trust_remote_code=True,
375
  torch_dtype=torch.float16,
376
  attn_implementation="flash_attention_2"
377
+ )
378
+ self.model = self.model.to(self.device_type)
379
  self.model = torch.compile(self.model, mode="reduce-overhead")
380
  logger.info(f"Translation model {model_name} loaded")
381
 
 
382
  class ModelManager:
383
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
384
  self.models = {}
 
389
  def load_model(self, src_lang, tgt_lang, key):
390
  logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
391
  translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
392
+ translate_manager.load()
393
  self.models[key] = translate_manager
394
  logger.info(f"Loaded translation model for {key}")
395
 
396
  def get_model(self, src_lang, tgt_lang):
397
  key = self._get_model_key(src_lang, tgt_lang)
398
+ if key not in self.models:
399
+ if self.is_lazy_loading:
400
+ self.load_model(src_lang, tgt_lang, key)
401
+ else:
402
+ raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
403
+ return self.models.get(key)
404
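With this change, lazy loading becomes explicit: a model pair is either preloaded at startup or loaded on first use when is_lazy_loading=True. A usage sketch (illustrative only; it would load real weights if run):

mm = ModelManager(is_lazy_loading=True)
tm = mm.get_model("eng_Latn", "kan_Knda")   # triggers load_model on first call
# With is_lazy_loading=False (the default), the same call raises ValueError
# unless the 'eng_indic' pair was preloaded during lifespan startup.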
 
405
  def _get_model_key(self, src_lang, tgt_lang):
406
  if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
 
413
 
414
  # ASR Manager
415
  class ASRModelManager:
416
+ def __init__(self, device_type="cuda"):
417
+ self.device_type = device_type
418
  self.model = None
419
  self.model_language = {"kannada": "kn"}
 
420
 
421
  def load(self):
422
  if not self.model:
423
+ logger.info("Loading ASR model...")
424
  self.model = AutoModel.from_pretrained(
425
  "ai4bharat/indic-conformer-600m-multilingual",
426
  trust_remote_code=True
427
+ )
428
+ self.model = self.model.to(self.device_type)
429
  logger.info("ASR model loaded")
430
 
431
  # Global Managers
432
  llm_manager = LLMManager(settings.llm_model_name)
433
  model_manager = ModelManager()
 
435
  tts_manager = TTSManager()
436
  ip = IndicProcessor(inference=True)
437
 
438
  # Pydantic Models
439
  class ChatRequest(BaseModel):
440
  prompt: str
441
  src_lang: str = "kan_Knda"
 
453
  raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
454
  return v
455
 
456
+
457
  class ChatResponse(BaseModel):
458
  response: str
459
 
 
468
  class TranslationResponse(BaseModel):
469
  translations: List[str]
470
 
471
  # Dependency
472
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
473
  return model_manager.get_model(src_lang, tgt_lang)
474
 
475
  # Lifespan Event Handler
476
  translation_configs = []
477
 
478
  @asynccontextmanager
479
  async def lifespan(app: FastAPI):
480
  def load_all_models():
481
+ try:
482
+ # Load LLM model
483
+ logger.info("Loading LLM model...")
484
+ llm_manager.load()
485
+ logger.info("LLM model loaded successfully")
486
+
487
+ # Load TTS model
488
+ logger.info("Loading TTS model...")
489
+ tts_manager.load()
490
+ logger.info("TTS model loaded successfully")
491
+
492
+ # Load ASR model
493
+ logger.info("Loading ASR model...")
494
+ asr_manager.load()
495
+ logger.info("ASR model loaded successfully")
496
+
497
+ # Load translation models
498
+ translation_tasks = [
499
+ ('eng_Latn', 'kan_Knda', 'eng_indic'),
500
+ ('kan_Knda', 'eng_Latn', 'indic_eng'),
501
+ ('kan_Knda', 'hin_Deva', 'indic_indic'),
502
+ ]
503
+
504
+ for config in translation_configs:
505
+ src_lang = config["src_lang"]
506
+ tgt_lang = config["tgt_lang"]
507
+ key = model_manager._get_model_key(src_lang, tgt_lang)
508
+ translation_tasks.append((src_lang, tgt_lang, key))
509
+
510
+ for src_lang, tgt_lang, key in translation_tasks:
511
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
512
+ model_manager.load_model(src_lang, tgt_lang, key)
513
+ logger.info(f"Translation model for {key} loaded successfully")
514
+
515
+ logger.info("All models loaded successfully")
516
+ except Exception as e:
517
+ logger.error(f"Error loading models: {str(e)}")
518
+ raise
519
 
520
+ logger.info("Starting sequential model loading...")
521
  load_all_models()
 
522
  yield
 
523
  llm_manager.unload()
524
+ logger.info("Server shutdown complete")
525
 
526
  # FastAPI App
527
  app = FastAPI(
528
+ title="Dhwani API",
529
+ description="AI Chat API supporting Indian languages",
530
  version="1.0.0",
531
  redirect_slashes=False,
532
  lifespan=lifespan
533
  )
534
 
535
+ # Add CORS Middleware
536
  app.add_middleware(
537
  CORSMiddleware,
538
  allow_origins=["*"],
 
541
  allow_headers=["*"],
542
  )
543
 
544
+ # Add Timing Middleware
545
  @app.middleware("http")
546
  async def add_request_timing(request: Request, call_next):
547
  start_time = time()
548
  response = await call_next(request)
549
+ end_time = time()
550
+ duration = end_time - start_time
551
  logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
552
  response.headers["X-Response-Time"] = f"{duration:.3f}"
553
  return response
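The timing header can be verified from any client; a hypothetical check, assuming the server listens on localhost:8000:

import requests

r = requests.get("http://localhost:8000/v1/health")
print(r.headers.get("X-Response-Time"))   # e.g. "0.004" (seconds)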
 
555
  limiter = Limiter(key_func=get_remote_address)
556
  app.state.limiter = limiter
557
 
558
+ # API Endpoints
559
  @app.post("/v1/audio/speech", response_class=StreamingResponse)
560
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
561
  if not tts_manager.model:
 
563
  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
564
  if not request.text.strip():
565
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
566
+
567
+ audio_buffer = synthesize_speech(
568
+ tts_manager,
569
+ text=request.text,
570
+ ref_audio_name="KAN_F (Happy)",
571
+ ref_text=kannada_example["ref_text"]
572
+ )
573
+
574
  return StreamingResponse(
575
  audio_buffer,
576
  media_type="audio/wav",
577
  headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
578
  )
579
 
580
+ @app.post("/v0/translate", response_model=TranslationResponse)
581
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
582
+ input_sentences = request.sentences
583
+ src_lang = request.src_lang
584
+ tgt_lang = request.tgt_lang
585
+
586
+ if not input_sentences:
587
  raise HTTPException(status_code=400, detail="Input sentences are required")
588
+
589
+ batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
590
+ inputs = translate_manager.tokenizer(
591
+ batch,
592
+ truncation=True,
593
+ padding="longest",
594
+ return_tensors="pt",
595
+ return_attention_mask=True,
596
+ ).to(translate_manager.device_type)
597
+
598
+ with torch.no_grad():
599
+ generated_tokens = translate_manager.model.generate(
600
+ **inputs,
601
+ use_cache=True,
602
+ min_length=0,
603
+ max_length=256,
604
+ num_beams=5,
605
+ num_return_sequences=1,
606
+ )
607
+
608
  with translate_manager.tokenizer.as_target_tokenizer():
609
+ generated_tokens = translate_manager.tokenizer.batch_decode(
610
+ generated_tokens.detach().cpu().tolist(),
611
+ skip_special_tokens=True,
612
+ clean_up_tokenization_spaces=True,
613
+ )
614
+
615
+ translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
616
  return TranslationResponse(translations=translations)
617
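A hypothetical client call against this endpoint (host and port are assumptions; the /v0/translate path is the one introduced by this commit):

import requests

resp = requests.post(
    "http://localhost:8000/v0/translate",
    json={
        "sentences": ["Hello, how are you?"],
        "src_lang": "eng_Latn",
        "tgt_lang": "kan_Knda",
    },
)
print(resp.json()["translations"])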
 
618
+ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
619
+ try:
620
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
621
+ except ValueError as e:
622
+ logger.info(f"Model not preloaded: {str(e)}, loading now...")
623
+ key = model_manager._get_model_key(src_lang, tgt_lang)
624
+ model_manager.load_model(src_lang, tgt_lang, key)
625
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
626
+
627
+ if not translate_manager.model:
628
+ translate_manager.load()
629
+
630
+ request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
631
+ response = await translate(request, translate_manager)
632
+ return response.translations
633
+
634
  @app.get("/v1/health")
635
  async def health_check():
636
+ return {"status": "healthy", "model": settings.llm_model_name}
637
 
638
  @app.get("/")
  async def home():

      try:
          logger.info("Starting to unload all models...")
          llm_manager.unload()
          logger.info("All models unloaded successfully")
          return {"status": "success", "message": "All models unloaded"}
      except Exception as e:

      try:
          logger.info("Starting to load all models...")
          llm_manager.load()
          logger.info("All models loaded successfully")
          return {"status": "success", "message": "All models loaded"}
      except Exception as e:

  async def translate_endpoint(request: TranslationRequest):
      logger.info(f"Received translation request: {request.dict()}")
      try:
+         translations = await perform_internal_translation(
+             sentences=request.sentences,
+             src_lang=request.src_lang,
+             tgt_lang=request.tgt_lang
+         )
          logger.info(f"Translation successful: {translations}")
          return TranslationResponse(translations=translations)
      except Exception as e:

  @app.post("/v1/chat", response_model=ChatResponse)
  @limiter.limit(settings.chat_rate_limit)
  async def chat(request: Request, chat_request: ChatRequest):
+     if not chat_request.prompt:
+         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+     logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
+
+     EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
+
+     try:
+         if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
+             translated_prompt = await perform_internal_translation(
+                 sentences=[chat_request.prompt],
+                 src_lang=chat_request.src_lang,
+                 tgt_lang="eng_Latn"
+             )
+             prompt_to_process = translated_prompt[0]
+             logger.info(f"Translated prompt to English: {prompt_to_process}")
+         else:
+             prompt_to_process = chat_request.prompt
+             logger.info("Prompt already in English or a European language, no translation needed")
+
+         response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+         logger.info(f"Generated response: {response}")
+
+         if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
+             translated_response = await perform_internal_translation(
+                 sentences=[response],
+                 src_lang="eng_Latn",
+                 tgt_lang=chat_request.tgt_lang
+             )
+             final_response = translated_response[0]
+             logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
+         else:
+             final_response = response
+             logger.info(f"Response kept in {chat_request.tgt_lang}, no translation needed")
+
+         return ChatResponse(response=final_response)
+     except Exception as e:
+         logger.error(f"Error processing request: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

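A hedged client-side sketch for this endpoint; the base URL is an assumption, while the JSON fields mirror how chat_request is read above.

import requests

payload = {
    "prompt": "What is the capital of Karnataka?",
    "src_lang": "eng_Latn",
    "tgt_lang": "kan_Knda",
}
resp = requests.post("http://localhost:7860/v1/chat", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])  # translated back to tgt_lang when it is not English/European
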
  @app.post("/v1/visual_query/")
  async def visual_query(

      src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
      tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
  ):
+     try:
+         image = Image.open(file.file)
+         if image.size == (0, 0):
+             raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
+
+         if src_lang != "eng_Latn":
+             translated_query = await perform_internal_translation(
+                 sentences=[query],
+                 src_lang=src_lang,
+                 tgt_lang="eng_Latn"
+             )
+             query_to_process = translated_query[0]
+             logger.info(f"Translated query to English: {query_to_process}")
+         else:
+             query_to_process = query
+             logger.info("Query already in English, no translation needed")
+
+         answer = await llm_manager.vision_query(image, query_to_process)
+         logger.info(f"Generated English answer: {answer}")
+
+         if tgt_lang != "eng_Latn":
+             translated_answer = await perform_internal_translation(
+                 sentences=[answer],
+                 src_lang="eng_Latn",
+                 tgt_lang=tgt_lang
+             )
+             final_answer = translated_answer[0]
+             logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
+         else:
+             final_answer = answer
+             logger.info("Answer kept in English, no translation needed")
+
+         return {"answer": final_answer}
+     except HTTPException:
+         raise  # let the explicit 400 above propagate instead of being remapped to 500
+     except Exception as e:
+         logger.error(f"Error processing request: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

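The upload and query parameter names are elided from this hunk, so the file/query names below are assumptions inferred from the handler body; the host/port likewise.

import requests

with open("sample.jpg", "rb") as img:
    resp = requests.post(
        "http://localhost:7860/v1/visual_query/",
        params={"src_lang": "kan_Knda", "tgt_lang": "kan_Knda"},
        files={"file": ("sample.jpg", img, "image/jpeg")},  # assumed parameter name
        data={"query": "ಈ ಚಿತ್ರದಲ್ಲಿ ಏನಿದೆ?"},  # assumed parameter name
        timeout=120,
    )
resp.raise_for_status()
print(resp.json()["answer"])
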
  @app.post("/v1/chat_v2", response_model=ChatResponse)
  @limiter.limit(settings.chat_rate_limit)

      src_lang: str = Form("kan_Knda"),
      tgt_lang: str = Form("kan_Knda"),
  ):
+     if not prompt:
+         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+     if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
+         raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
+
+     logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
+
+     try:
+         if image:
+             image_data = await image.read()
+             if not image_data:
+                 raise HTTPException(status_code=400, detail="Uploaded image is empty")
+             img = Image.open(io.BytesIO(image_data))
+
+             if src_lang != "eng_Latn":
+                 translated_prompt = await perform_internal_translation(
+                     sentences=[prompt],
+                     src_lang=src_lang,
+                     tgt_lang="eng_Latn"
+                 )
+                 prompt_to_process = translated_prompt[0]
+                 logger.info(f"Translated prompt to English: {prompt_to_process}")
              else:
+                 prompt_to_process = prompt
+                 logger.info("Prompt already in English, no translation needed")
+
+             decoded = await llm_manager.chat_v2(img, prompt_to_process)
+             logger.info(f"Generated English response: {decoded}")
+
+             if tgt_lang != "eng_Latn":
+                 translated_response = await perform_internal_translation(
+                     sentences=[decoded],
+                     src_lang="eng_Latn",
+                     tgt_lang=tgt_lang
+                 )
+                 final_response = translated_response[0]
+                 logger.info(f"Translated response to {tgt_lang}: {final_response}")
+             else:
+                 final_response = decoded
+                 logger.info("Response kept in English, no translation needed")
+         else:
+             if src_lang != "eng_Latn":
+                 translated_prompt = await perform_internal_translation(
+                     sentences=[prompt],
+                     src_lang=src_lang,
+                     tgt_lang="eng_Latn"
+                 )
+                 prompt_to_process = translated_prompt[0]
+                 logger.info(f"Translated prompt to English: {prompt_to_process}")
+             else:
+                 prompt_to_process = prompt
+                 logger.info("Prompt already in English, no translation needed")
+
+             decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+             logger.info(f"Generated English response: {decoded}")
+
+             if tgt_lang != "eng_Latn":
+                 translated_response = await perform_internal_translation(
+                     sentences=[decoded],
+                     src_lang="eng_Latn",
+                     tgt_lang=tgt_lang
+                 )
+                 final_response = translated_response[0]
+                 logger.info(f"Translated response to {tgt_lang}: {final_response}")
+             else:
+                 final_response = decoded
+                 logger.info("Response kept in English, no translation needed")
+
+         return ChatResponse(response=final_response)
+     except HTTPException:
+         raise  # preserve the 400s raised above
+     except Exception as e:
+         logger.error(f"Error processing request: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

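chat_v2 differs from /v1/chat in taking form fields plus an optional image. A sketch, with field names taken from the handler body (the signature itself is partially elided here) and the base URL assumed:

import requests

with open("photo.jpg", "rb") as img:
    resp = requests.post(
        "http://localhost:7860/v1/chat_v2",
        data={"prompt": "Describe this image", "src_lang": "eng_Latn", "tgt_lang": "eng_Latn"},
        files={"image": ("photo.jpg", img, "image/jpeg")},  # omit to exercise the text-only branch
        timeout=120,
    )
resp.raise_for_status()
print(resp.json()["response"])
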
  @app.post("/v1/transcribe/", response_model=TranscriptionResponse)
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
+     if not asr_manager.model:
+         raise HTTPException(status_code=503, detail="ASR model not loaded")
+     try:
+         wav, sr = torchaudio.load(file.file)
+         wav = torch.mean(wav, dim=0, keepdim=True)  # downmix multi-channel audio to mono
+         target_sample_rate = 16000
+         if sr != target_sample_rate:
+             resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
+             wav = resampler(wav)
+         transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+         return TranscriptionResponse(text=transcription_rnnt)
+     except Exception as e:
+         logger.error(f"Error in transcription: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")

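A client sketch for the ASR route. "kannada" is an assumed language key, valid only if the running config registered that language in asr_manager.model_language; host/port are assumptions as before.

import requests

with open("clip.wav", "rb") as audio:
    resp = requests.post(
        "http://localhost:7860/v1/transcribe/",
        params={"language": "kannada"},  # assumed key; must match the server's configured language
        files={"file": ("clip.wav", audio, "audio/wav")},
        timeout=120,
    )
resp.raise_for_status()
print(resp.json()["text"])
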
  @app.post("/v1/speech_to_speech")
  async def speech_to_speech(
      request: Request,
      file: UploadFile = File(...),
      language: str = Query(..., enum=list(asr_manager.model_language.keys())),
  ) -> StreamingResponse:
+     if not tts_manager.model:
+         raise HTTPException(status_code=503, detail="TTS model not loaded")
+     transcription = await transcribe_audio(file, language)
+     logger.info(f"Transcribed text: {transcription.text}")
+
+     chat_request = ChatRequest(
+         prompt=transcription.text,
+         src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
+         tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
+     )
+     processed_text = await chat(request, chat_request)
+     logger.info(f"Processed text: {processed_text.response}")
+
+     voice_request = KannadaSynthesizeRequest(text=processed_text.response)
+     audio_response = await synthesize_kannada(voice_request)
+     return audio_response
+
+ # Maps an ASR language key to the script code used by the translation stack.
+ LANGUAGE_TO_SCRIPT = {
+     "kannada": "kan_Knda"
+ }
+
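The route chains the three stages above (transcribe, chat, synthesize) and streams audio back. A client sketch under the same host/port and language-key assumptions:

import requests

with open("question.wav", "rb") as audio:
    resp = requests.post(
        "http://localhost:7860/v1/speech_to_speech",
        params={"language": "kannada"},
        files={"file": ("question.wav", audio, "audio/wav")},
        timeout=300,  # three model passes, so allow a generous timeout
    )
resp.raise_for_status()
with open("reply.wav", "wb") as f:
    f.write(resp.content)
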
+ # Main Execution
  if __name__ == "__main__":
      parser = argparse.ArgumentParser(description="Run the FastAPI server.")
      parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")

      settings.speech_rate_limit = global_settings["speech_rate_limit"]

      llm_manager = LLMManager(settings.llm_model_name)
+
      if selected_config["components"]["ASR"]:
+         asr_model_name = selected_config["components"]["ASR"]["model"]
          asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
+
      if selected_config["components"]["Translation"]:
          translation_configs.extend(selected_config["components"]["Translation"])

      host = args.host if args.host != settings.host else settings.host
      port = args.port if args.port != settings.port else settings.port

+     uvicorn.run(app, host=host, port=port)
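
With the pieces above in place, the server can be launched directly, e.g. python src/server/main.py --host 0.0.0.0 --port 7860. The port value here is illustrative; the actual defaults come from settings.port and the argument parsing elided from this hunk.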