sachin committed on
Commit
4ec2bfb
·
1 Parent(s): 20a258b
Dockerfile CHANGED
@@ -8,6 +8,8 @@ COPY . .
8
  RUN useradd -ms /bin/bash appuser \
9
  && chown -R appuser:appuser /app
10
  USER appuser
 
 
11
  ENV HF_HOME=/data/huggingface
12
  # Expose port
13
  EXPOSE 7860
 
8
  RUN useradd -ms /bin/bash appuser \
9
  && chown -R appuser:appuser /app
10
  USER appuser
11
+
12
+
13
  ENV HF_HOME=/data/huggingface
14
  # Expose port
15
  EXPOSE 7860
Dockerfile.app CHANGED
@@ -1,9 +1,7 @@
1
- # Use the pre-built image with models as the base
2
- FROM slabstech/dhwani-model-server:latest
3
  WORKDIR /app
4
 
5
- COPY dhwani_config.json .
6
- # Copy application code
7
  COPY . .
8
 
9
  # Set up user
@@ -11,6 +9,8 @@ RUN useradd -ms /bin/bash appuser \
11
  && chown -R appuser:appuser /app
12
  USER appuser
13
 
 
 
14
  # Expose port
15
  EXPOSE 7860
16
 
 
1
+ FROM slabstech/dhwani-server-base
 
2
  WORKDIR /app
3
 
4
+ ENV HF_HOME=/data/huggingface
 
5
  COPY . .
6
 
7
  # Set up user
 
9
  && chown -R appuser:appuser /app
10
  USER appuser
11
 
12
+
13
+ ENV HF_HOME=/data/huggingface
14
  # Expose port
15
  EXPOSE 7860
16
 
Dockerfile.models DELETED
@@ -1,17 +0,0 @@
1
- # Base image with CUDA support
2
- FROM slabstech/dhwani-server-base:latest
3
-
4
-
5
- # Create a directory for pre-downloaded models
6
- RUN mkdir -p /app/models
7
-
8
- # Define build argument for HF_TOKEN
9
- ARG HF_TOKEN_DOCKER
10
-
11
- # Set environment variable for the build process
12
- ENV HF_TOKEN=$HF_TOKEN_DOCKER
13
-
14
- # Copy and run the model download script
15
- COPY download_models.py .
16
- COPY . .
17
- RUN python download_models.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
download_models.py DELETED
@@ -1,35 +0,0 @@
1
- #!/usr/bin/env python3
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, AutoModel
3
- from transformers import Gemma3ForConditionalGeneration
4
- import os
5
-
6
- # Get the Hugging Face token from environment variable
7
- hf_token = os.getenv("HF_TOKEN")
8
- if not hf_token:
9
- print("Warning: HF_TOKEN not set. Some models may require authentication.")
10
-
11
- # Define the models to download
12
- models = {
13
- #'llm_model': ('google/gemma-3-4b-it', Gemma3ForConditionalGeneration, AutoProcessor),
14
- 'tts_model': ('ai4bharat/IndicF5', AutoModel, None),
15
- #'asr_model': ('ai4bharat/indic-conformer-600m-multilingual', AutoModel, None),
16
- 'trans_en_indic': ('ai4bharat/indictrans2-en-indic-dist-200M', AutoModelForSeq2SeqLM, AutoTokenizer),
17
- 'trans_indic_en': ('ai4bharat/indictrans2-indic-en-dist-200M', AutoModelForSeq2SeqLM, AutoTokenizer),
18
- 'trans_indic_indic': ('ai4bharat/indictrans2-indic-indic-dist-320M', AutoModelForSeq2SeqLM, AutoTokenizer),
19
- }
20
-
21
- # Directory to save models
22
- save_dir = '/app/models'
23
-
24
- # Ensure the directory exists
25
- os.makedirs(save_dir, exist_ok=True)
26
-
27
- # Download and save each model
28
- for name, (model_name, model_class, processor_class) in models.items():
29
- print(f'Downloading {model_name}...')
30
- model = model_class.from_pretrained(model_name, trust_remote_code=True, token=hf_token)
31
- model.save_pretrained(f'{save_dir}/{name}')
32
- if processor_class:
33
- processor = processor_class.from_pretrained(model_name, trust_remote_code=True, token=hf_token)
34
- processor.save_pretrained(f'{save_dir}/{name}')
35
- print(f'Saved {model_name} to {save_dir}/{name}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/server/main.py CHANGED
@@ -88,23 +88,19 @@ class LLMManager:
88
  self.is_loaded = False
89
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
90
 
91
- async def load(self):
92
  if not self.is_loaded:
93
  try:
94
- self.model = await asyncio.to_thread(
95
- Gemma3ForConditionalGeneration.from_pretrained,
96
  self.model_name,
97
  device_map="auto",
98
  quantization_config=quantization_config,
99
  torch_dtype=self.torch_dtype
100
  )
101
  self.model.eval()
102
- self.processor = await asyncio.to_thread(
103
- AutoProcessor.from_pretrained,
104
- self.model_name
105
- )
106
  self.is_loaded = True
107
- logger.info(f"LLM {self.model_name} loaded asynchronously on {self.device}")
108
  except Exception as e:
109
  logger.error(f"Failed to load LLM: {str(e)}")
110
  raise
@@ -121,7 +117,7 @@ class LLMManager:
121
 
122
  async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
123
  if not self.is_loaded:
124
- await self.load()
125
 
126
  messages_vlm = [
127
  {
@@ -163,7 +159,7 @@ class LLMManager:
163
 
164
  async def vision_query(self, image: Image.Image, query: str) -> str:
165
  if not self.is_loaded:
166
- await self.load()
167
 
168
  messages_vlm = [
169
  {
@@ -212,7 +208,7 @@ class LLMManager:
212
 
213
  async def chat_v2(self, image: Image.Image, query: str) -> str:
214
  if not self.is_loaded:
215
- await self.load()
216
 
217
  messages_vlm = [
218
  {
@@ -266,16 +262,15 @@ class TTSManager:
266
  self.model = None
267
  self.repo_id = "ai4bharat/IndicF5"
268
 
269
- async def load(self):
270
  if not self.model:
271
- logger.info("Loading TTS model IndicF5 asynchronously...")
272
- self.model = await asyncio.to_thread(
273
- AutoModel.from_pretrained,
274
  self.repo_id,
275
  trust_remote_code=True
276
  )
277
  self.model = self.model.to(self.device_type)
278
- logger.info("TTS model IndicF5 loaded asynchronously")
279
 
280
  def synthesize(self, text, ref_audio_path, ref_text):
281
  if not self.model:
@@ -359,7 +354,7 @@ class TranslateManager:
359
  self.tgt_lang = tgt_lang
360
  self.use_distilled = use_distilled
361
 
362
- async def load(self):
363
  if not self.tokenizer or not self.model:
364
  if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
365
  model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
@@ -370,13 +365,11 @@ class TranslateManager:
370
  else:
371
  raise ValueError("Invalid language combination")
372
 
373
- self.tokenizer = await asyncio.to_thread(
374
- AutoTokenizer.from_pretrained,
375
  model_name,
376
  trust_remote_code=True
377
  )
378
- self.model = await asyncio.to_thread(
379
- AutoModelForSeq2SeqLM.from_pretrained,
380
  model_name,
381
  trust_remote_code=True,
382
  torch_dtype=torch.float16,
@@ -384,7 +377,7 @@ class TranslateManager:
384
  )
385
  self.model = self.model.to(self.device_type)
386
  self.model = torch.compile(self.model, mode="reduce-overhead")
387
- logger.info(f"Translation model {model_name} loaded asynchronously")
388
 
389
  class ModelManager:
390
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
@@ -393,18 +386,18 @@ class ModelManager:
393
  self.use_distilled = use_distilled
394
  self.is_lazy_loading = is_lazy_loading
395
 
396
- async def load_model(self, src_lang, tgt_lang, key):
397
- logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} asynchronously")
398
  translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
399
- await translate_manager.load()
400
  self.models[key] = translate_manager
401
- logger.info(f"Loaded translation model for {key} asynchronously")
402
 
403
  def get_model(self, src_lang, tgt_lang):
404
  key = self._get_model_key(src_lang, tgt_lang)
405
  if key not in self.models:
406
  if self.is_lazy_loading:
407
- asyncio.create_task(self.load_model(src_lang, tgt_lang, key))
408
  else:
409
  raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
410
  return self.models.get(key)
@@ -425,16 +418,15 @@ class ASRModelManager:
425
  self.model = None
426
  self.model_language = {"kannada": "kn"}
427
 
428
- async def load(self):
429
  if not self.model:
430
- logger.info("Loading ASR model asynchronously...")
431
- self.model = await asyncio.to_thread(
432
- AutoModel.from_pretrained,
433
  "ai4bharat/indic-conformer-600m-multilingual",
434
  trust_remote_code=True
435
  )
436
  self.model = self.model.to(self.device_type)
437
- logger.info("ASR model loaded asynchronously")
438
 
439
  # Global Managers
440
  llm_manager = LLMManager(settings.llm_model_name)
@@ -484,34 +476,48 @@ translation_configs = []
484
 
485
  @asynccontextmanager
486
  async def lifespan(app: FastAPI):
487
- async def load_all_models():
488
  try:
489
- tasks = [
490
- llm_manager.load(),
491
- tts_manager.load(),
492
- asr_manager.load(),
493
- ]
494
-
 
 
 
 
 
 
 
 
 
 
495
  translation_tasks = [
496
- model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic'),
497
- model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng'),
498
- model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic'),
499
  ]
500
 
501
  for config in translation_configs:
502
  src_lang = config["src_lang"]
503
  tgt_lang = config["tgt_lang"]
504
  key = model_manager._get_model_key(src_lang, tgt_lang)
505
- translation_tasks.append(model_manager.load_model(src_lang, tgt_lang, key))
 
 
 
 
 
506
 
507
- await asyncio.gather(*tasks, *translation_tasks)
508
- logger.info("All models loaded successfully asynchronously")
509
  except Exception as e:
510
  logger.error(f"Error loading models: {str(e)}")
511
  raise
512
 
513
- logger.info("Starting asynchronous model loading...")
514
- await load_all_models()
515
  yield
516
  llm_manager.unload()
517
  logger.info("Server shutdown complete")
@@ -602,11 +608,11 @@ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_
602
  except ValueError as e:
603
  logger.info(f"Model not preloaded: {str(e)}, loading now...")
604
  key = model_manager._get_model_key(src_lang, tgt_lang)
605
- await model_manager.load_model(src_lang, tgt_lang, key)
606
  translate_manager = model_manager.get_model(src_lang, tgt_lang)
607
 
608
  if not translate_manager.model:
609
- await translate_manager.load()
610
 
611
  request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
612
  response = await translate(request, translate_manager)
@@ -635,7 +641,7 @@ async def unload_all_models():
635
  async def load_all_models():
636
  try:
637
  logger.info("Starting to load all models...")
638
- await llm_manager.load()
639
  logger.info("All models loaded successfully")
640
  return {"status": "success", "message": "All models loaded"}
641
  except Exception as e:
 
88
  self.is_loaded = False
89
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
90
 
91
+ def load(self):
92
  if not self.is_loaded:
93
  try:
94
+ self.model = Gemma3ForConditionalGeneration.from_pretrained(
 
95
  self.model_name,
96
  device_map="auto",
97
  quantization_config=quantization_config,
98
  torch_dtype=self.torch_dtype
99
  )
100
  self.model.eval()
101
+ self.processor = AutoProcessor.from_pretrained(self.model_name)
 
 
 
102
  self.is_loaded = True
103
+ logger.info(f"LLM {self.model_name} loaded on {self.device}")
104
  except Exception as e:
105
  logger.error(f"Failed to load LLM: {str(e)}")
106
  raise
 
117
 
118
  async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
119
  if not self.is_loaded:
120
+ self.load()
121
 
122
  messages_vlm = [
123
  {
 
159
 
160
  async def vision_query(self, image: Image.Image, query: str) -> str:
161
  if not self.is_loaded:
162
+ self.load()
163
 
164
  messages_vlm = [
165
  {
 
208
 
209
  async def chat_v2(self, image: Image.Image, query: str) -> str:
210
  if not self.is_loaded:
211
+ self.load()
212
 
213
  messages_vlm = [
214
  {
 
262
  self.model = None
263
  self.repo_id = "ai4bharat/IndicF5"
264
 
265
+ def load(self):
266
  if not self.model:
267
+ logger.info("Loading TTS model IndicF5...")
268
+ self.model = AutoModel.from_pretrained(
 
269
  self.repo_id,
270
  trust_remote_code=True
271
  )
272
  self.model = self.model.to(self.device_type)
273
+ logger.info("TTS model IndicF5 loaded")
274
 
275
  def synthesize(self, text, ref_audio_path, ref_text):
276
  if not self.model:
 
354
  self.tgt_lang = tgt_lang
355
  self.use_distilled = use_distilled
356
 
357
+ def load(self):
358
  if not self.tokenizer or not self.model:
359
  if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
360
  model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
 
365
  else:
366
  raise ValueError("Invalid language combination")
367
 
368
+ self.tokenizer = AutoTokenizer.from_pretrained(
 
369
  model_name,
370
  trust_remote_code=True
371
  )
372
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
 
373
  model_name,
374
  trust_remote_code=True,
375
  torch_dtype=torch.float16,
 
377
  )
378
  self.model = self.model.to(self.device_type)
379
  self.model = torch.compile(self.model, mode="reduce-overhead")
380
+ logger.info(f"Translation model {model_name} loaded")
381
 
382
  class ModelManager:
383
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
 
386
  self.use_distilled = use_distilled
387
  self.is_lazy_loading = is_lazy_loading
388
 
389
+ def load_model(self, src_lang, tgt_lang, key):
390
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
391
  translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
392
+ translate_manager.load()
393
  self.models[key] = translate_manager
394
+ logger.info(f"Loaded translation model for {key}")
395
 
396
  def get_model(self, src_lang, tgt_lang):
397
  key = self._get_model_key(src_lang, tgt_lang)
398
  if key not in self.models:
399
  if self.is_lazy_loading:
400
+ self.load_model(src_lang, tgt_lang, key)
401
  else:
402
  raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
403
  return self.models.get(key)
 
418
  self.model = None
419
  self.model_language = {"kannada": "kn"}
420
 
421
+ def load(self):
422
  if not self.model:
423
+ logger.info("Loading ASR model...")
424
+ self.model = AutoModel.from_pretrained(
 
425
  "ai4bharat/indic-conformer-600m-multilingual",
426
  trust_remote_code=True
427
  )
428
  self.model = self.model.to(self.device_type)
429
+ logger.info("ASR model loaded")
430
 
431
  # Global Managers
432
  llm_manager = LLMManager(settings.llm_model_name)
 
476
 
477
  @asynccontextmanager
478
  async def lifespan(app: FastAPI):
479
+ def load_all_models():
480
  try:
481
+ # Load LLM model
482
+ logger.info("Loading LLM model...")
483
+ llm_manager.load()
484
+ logger.info("LLM model loaded successfully")
485
+
486
+ # Load TTS model
487
+ logger.info("Loading TTS model...")
488
+ tts_manager.load()
489
+ logger.info("TTS model loaded successfully")
490
+
491
+ # Load ASR model
492
+ logger.info("Loading ASR model...")
493
+ asr_manager.load()
494
+ logger.info("ASR model loaded successfully")
495
+
496
+ # Load translation models
497
  translation_tasks = [
498
+ ('eng_Latn', 'kan_Knda', 'eng_indic'),
499
+ ('kan_Knda', 'eng_Latn', 'indic_eng'),
500
+ ('kan_Knda', 'hin_Deva', 'indic_indic'),
501
  ]
502
 
503
  for config in translation_configs:
504
  src_lang = config["src_lang"]
505
  tgt_lang = config["tgt_lang"]
506
  key = model_manager._get_model_key(src_lang, tgt_lang)
507
+ translation_tasks.append((src_lang, tgt_lang, key))
508
+
509
+ for src_lang, tgt_lang, key in translation_tasks:
510
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
511
+ model_manager.load_model(src_lang, tgt_lang, key)
512
+ logger.info(f"Translation model for {key} loaded successfully")
513
 
514
+ logger.info("All models loaded successfully")
 
515
  except Exception as e:
516
  logger.error(f"Error loading models: {str(e)}")
517
  raise
518
 
519
+ logger.info("Starting sequential model loading...")
520
+ load_all_models()
521
  yield
522
  llm_manager.unload()
523
  logger.info("Server shutdown complete")
 
608
  except ValueError as e:
609
  logger.info(f"Model not preloaded: {str(e)}, loading now...")
610
  key = model_manager._get_model_key(src_lang, tgt_lang)
611
+ model_manager.load_model(src_lang, tgt_lang, key)
612
  translate_manager = model_manager.get_model(src_lang, tgt_lang)
613
 
614
  if not translate_manager.model:
615
+ translate_manager.load()
616
 
617
  request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
618
  response = await translate(request, translate_manager)
 
641
  async def load_all_models():
642
  try:
643
  logger.info("Starting to load all models...")
644
+ llm_manager.load()
645
  logger.info("All models loaded successfully")
646
  return {"status": "success", "message": "All models loaded"}
647
  except Exception as e:
src/server/main_hfy.py DELETED
@@ -1,910 +0,0 @@
1
- import argparse
2
- import io
3
- import os
4
- from time import time
5
- from typing import List
6
- import tempfile
7
- import uvicorn
8
- from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
9
- from fastapi.middleware.cors import CORSMiddleware
10
- from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
- from PIL import Image
12
- from pydantic import BaseModel, field_validator
13
- from pydantic_settings import BaseSettings
14
- from slowapi import Limiter
15
- from slowapi.util import get_remote_address
16
- import torch
17
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, AutoModel, Gemma3ForConditionalGeneration
18
- from IndicTransToolkit import IndicProcessor
19
- import json
20
- import asyncio
21
- from contextlib import asynccontextmanager
22
- import soundfile as sf
23
- import numpy as np
24
- import requests
25
- from starlette.responses import StreamingResponse
26
- from logging_config import logger
27
- from tts_config import SPEED, ResponseFormat, config as tts_config
28
- import torchaudio
29
-
30
- # Device setup
31
- if torch.cuda.is_available():
32
- device = "cuda:0"
33
- logger.info("GPU will be used for inference")
34
- else:
35
- device = "cpu"
36
- logger.info("CPU will be used for inference")
37
- torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
38
-
39
- # Check CUDA availability and version
40
- cuda_available = torch.cuda.is_available()
41
- cuda_version = torch.version.cuda if cuda_available else None
42
-
43
- if torch.cuda.is_available():
44
- device_idx = torch.cuda.current_device()
45
- capability = torch.cuda.get_device_capability(device_idx)
46
- compute_capability_float = float(f"{capability[0]}.{capability[1]}")
47
- print(f"CUDA version: {cuda_version}")
48
- print(f"CUDA Compute Capability: {compute_capability_float}")
49
- else:
50
- print("CUDA is not available on this system.")
51
-
52
- # Settings
53
- class Settings(BaseSettings):
54
- llm_model_name: str = "google/gemma-3-4b-it"
55
- max_tokens: int = 512
56
- host: str = "0.0.0.0"
57
- port: int = 7860
58
- chat_rate_limit: str = "100/minute"
59
- speech_rate_limit: str = "5/minute"
60
-
61
- @field_validator("chat_rate_limit", "speech_rate_limit")
62
- def validate_rate_limit(cls, v):
63
- if not v.count("/") == 1 or not v.split("/")[0].isdigit():
64
- raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
65
- return v
66
-
67
- class Config:
68
- env_file = ".env"
69
-
70
- settings = Settings()
71
-
72
- # Quantization config for LLM
73
- quantization_config = BitsAndBytesConfig(
74
- load_in_4bit=True,
75
- bnb_4bit_quant_type="nf4",
76
- bnb_4bit_use_double_quant=True,
77
- bnb_4bit_compute_dtype=torch.bfloat16
78
- )
79
-
80
- # LLM Manager
81
- class LLMManager:
82
- def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
83
- self.model_name = model_name
84
- self.device = torch.device(device)
85
- self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
86
- self.model = None
87
- self.processor = None
88
- self.is_loaded = False
89
- logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
90
-
91
- async def load(self):
92
- if not self.is_loaded:
93
- try:
94
- self.model = await asyncio.to_thread(
95
- Gemma3ForConditionalGeneration.from_pretrained,
96
- self.model_name,
97
- device_map="auto",
98
- quantization_config=quantization_config,
99
- torch_dtype=self.torch_dtype
100
- )
101
- self.model.eval()
102
- self.processor = await asyncio.to_thread(
103
- AutoProcessor.from_pretrained,
104
- self.model_name
105
- )
106
- self.is_loaded = True
107
- logger.info(f"LLM {self.model_name} loaded asynchronously on {self.device}")
108
- except Exception as e:
109
- logger.error(f"Failed to load LLM: {str(e)}")
110
- raise
111
-
112
- def unload(self):
113
- if self.is_loaded:
114
- del self.model
115
- del self.processor
116
- if self.device.type == "cuda":
117
- torch.cuda.empty_cache()
118
- logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
119
- self.is_loaded = False
120
- logger.info(f"LLM {self.model_name} unloaded from {self.device}")
121
-
122
- async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
123
- if not self.is_loaded:
124
- await self.load()
125
-
126
- messages_vlm = [
127
- {
128
- "role": "system",
129
- "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
130
- },
131
- {
132
- "role": "user",
133
- "content": [{"type": "text", "text": prompt}]
134
- }
135
- ]
136
-
137
- try:
138
- inputs_vlm = self.processor.apply_chat_template(
139
- messages_vlm,
140
- add_generation_prompt=True,
141
- tokenize=True,
142
- return_dict=True,
143
- return_tensors="pt"
144
- ).to(self.device, dtype=torch.bfloat16)
145
- except Exception as e:
146
- logger.error(f"Error in tokenization: {str(e)}")
147
- raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
148
-
149
- input_len = inputs_vlm["input_ids"].shape[-1]
150
-
151
- with torch.inference_mode():
152
- generation = self.model.generate(
153
- **inputs_vlm,
154
- max_new_tokens=max_tokens,
155
- do_sample=True,
156
- temperature=temperature
157
- )
158
- generation = generation[0][input_len:]
159
-
160
- response = self.processor.decode(generation, skip_special_tokens=True)
161
- logger.info(f"Generated response: {response}")
162
- return response
163
-
164
- async def vision_query(self, image: Image.Image, query: str) -> str:
165
- if not self.is_loaded:
166
- await self.load()
167
-
168
- messages_vlm = [
169
- {
170
- "role": "system",
171
- "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
172
- },
173
- {
174
- "role": "user",
175
- "content": []
176
- }
177
- ]
178
-
179
- messages_vlm[1]["content"].append({"type": "text", "text": query})
180
- if image and image.size[0] > 0 and image.size[1] > 0:
181
- messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
182
- logger.info(f"Received valid image for processing")
183
- else:
184
- logger.info("No valid image provided, processing text only")
185
-
186
- try:
187
- inputs_vlm = self.processor.apply_chat_template(
188
- messages_vlm,
189
- add_generation_prompt=True,
190
- tokenize=True,
191
- return_dict=True,
192
- return_tensors="pt"
193
- ).to(self.device, dtype=torch.bfloat16)
194
- except Exception as e:
195
- logger.error(f"Error in apply_chat_template: {str(e)}")
196
- raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
197
-
198
- input_len = inputs_vlm["input_ids"].shape[-1]
199
-
200
- with torch.inference_mode():
201
- generation = self.model.generate(
202
- **inputs_vlm,
203
- max_new_tokens=512,
204
- do_sample=True,
205
- temperature=0.7
206
- )
207
- generation = generation[0][input_len:]
208
-
209
- decoded = self.processor.decode(generation, skip_special_tokens=True)
210
- logger.info(f"Vision query response: {decoded}")
211
- return decoded
212
-
213
- async def chat_v2(self, image: Image.Image, query: str) -> str:
214
- if not self.is_loaded:
215
- await self.load()
216
-
217
- messages_vlm = [
218
- {
219
- "role": "system",
220
- "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
221
- },
222
- {
223
- "role": "user",
224
- "content": []
225
- }
226
- ]
227
-
228
- messages_vlm[1]["content"].append({"type": "text", "text": query})
229
- if image and image.size[0] > 0 and image.size[1] > 0:
230
- messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
231
- logger.info(f"Received valid image for processing")
232
- else:
233
- logger.info("No valid image provided, processing text only")
234
-
235
- try:
236
- inputs_vlm = self.processor.apply_chat_template(
237
- messages_vlm,
238
- add_generation_prompt=True,
239
- tokenize=True,
240
- return_dict=True,
241
- return_tensors="pt"
242
- ).to(self.device, dtype=torch.bfloat16)
243
- except Exception as e:
244
- logger.error(f"Error in apply_chat_template: {str(e)}")
245
- raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
246
-
247
- input_len = inputs_vlm["input_ids"].shape[-1]
248
-
249
- with torch.inference_mode():
250
- generation = self.model.generate(
251
- **inputs_vlm,
252
- max_new_tokens=512,
253
- do_sample=True,
254
- temperature=0.7
255
- )
256
- generation = generation[0][input_len:]
257
-
258
- decoded = self.processor.decode(generation, skip_special_tokens=True)
259
- logger.info(f"Chat_v2 response: {decoded}")
260
- return decoded
261
-
262
- # TTS Manager
263
- class TTSManager:
264
- def __init__(self, device_type=device):
265
- self.device_type = device_type
266
- self.model = None
267
- self.repo_id = "ai4bharat/IndicF5"
268
-
269
- async def load(self):
270
- if not self.model:
271
- logger.info("Loading TTS model IndicF5 asynchronously...")
272
- self.model = await asyncio.to_thread(
273
- AutoModel.from_pretrained,
274
- self.repo_id,
275
- trust_remote_code=True
276
- )
277
- self.model = self.model.to(self.device_type)
278
- logger.info("TTS model IndicF5 loaded asynchronously")
279
-
280
- def synthesize(self, text, ref_audio_path, ref_text):
281
- if not self.model:
282
- raise ValueError("TTS model not loaded")
283
- return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
284
-
285
- # TTS Constants
286
- EXAMPLES = [
287
- {
288
- "audio_name": "KAN_F (Happy)",
289
- "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
290
- "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
291
- "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
292
- },
293
- ]
294
-
295
- # Pydantic models for TTS
296
- class SynthesizeRequest(BaseModel):
297
- text: str
298
- ref_audio_name: str
299
- ref_text: str = None
300
-
301
- class KannadaSynthesizeRequest(BaseModel):
302
- text: str
303
-
304
- # TTS Functions
305
- def load_audio_from_url(url: str):
306
- response = requests.get(url)
307
- if response.status_code == 200:
308
- audio_data, sample_rate = sf.read(io.BytesIO(response.content))
309
- return sample_rate, audio_data
310
- raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
311
-
312
- def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
313
- ref_audio_url = None
314
- for example in EXAMPLES:
315
- if example["audio_name"] == ref_audio_name:
316
- ref_audio_url = example["audio_url"]
317
- if not ref_text:
318
- ref_text = example["ref_text"]
319
- break
320
-
321
- if not ref_audio_url:
322
- raise HTTPException(status_code=400, detail="Invalid reference audio name.")
323
- if not text.strip():
324
- raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
325
- if not ref_text or not ref_text.strip():
326
- raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
327
-
328
- sample_rate, audio_data = load_audio_from_url(ref_audio_url)
329
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
330
- sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
331
- temp_audio.flush()
332
- audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
333
-
334
- if audio.dtype == np.int16:
335
- audio = audio.astype(np.float32) / 32768.0
336
- buffer = io.BytesIO()
337
- sf.write(buffer, audio, 24000, format='WAV')
338
- buffer.seek(0)
339
- return buffer
340
-
341
- # Supported languages
342
- SUPPORTED_LANGUAGES = {
343
- "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
344
- "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
345
- "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
346
- "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
347
- "kan_Knda", "ory_Orya",
348
- "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
349
- "por_Latn", "rus_Cyrl", "pol_Latn"
350
- }
351
-
352
- # Translation Manager
353
- class TranslateManager:
354
- def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
355
- self.device_type = device_type
356
- self.tokenizer = None
357
- self.model = None
358
- self.src_lang = src_lang
359
- self.tgt_lang = tgt_lang
360
- self.use_distilled = use_distilled
361
-
362
- async def load(self):
363
- if not self.tokenizer or not self.model:
364
- if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
365
- model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
366
- elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
367
- model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
368
- elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
369
- model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
370
- else:
371
- raise ValueError("Invalid language combination")
372
-
373
- self.tokenizer = await asyncio.to_thread(
374
- AutoTokenizer.from_pretrained,
375
- model_name,
376
- trust_remote_code=True
377
- )
378
- self.model = await asyncio.to_thread(
379
- AutoModelForSeq2SeqLM.from_pretrained,
380
- model_name,
381
- trust_remote_code=True,
382
- torch_dtype=torch.float16,
383
- attn_implementation="flash_attention_2"
384
- )
385
- self.model = self.model.to(self.device_type)
386
- self.model = torch.compile(self.model, mode="reduce-overhead")
387
- logger.info(f"Translation model {model_name} loaded asynchronously")
388
-
389
class ModelManager:
    """Registry of loaded ``TranslateManager`` objects, keyed by direction.

    Keys are the strings produced by ``_get_model_key``: ``'eng_indic'``,
    ``'indic_eng'`` and ``'indic_indic'``.
    """

    def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
        self.models = {}
        self.device_type = device_type
        self.use_distilled = use_distilled
        self.is_lazy_loading = is_lazy_loading
        # Background lazy-load tasks by key. The event loop only keeps weak
        # references to tasks, so they are held here until they finish; this
        # also prevents scheduling the same load again on every request.
        self._pending_loads = {}

    async def load_model(self, src_lang, tgt_lang, key):
        """Create, load and register the translation model for ``key``."""
        logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} asynchronously")
        translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
        await translate_manager.load()
        self.models[key] = translate_manager
        logger.info(f"Loaded translation model for {key} asynchronously")

    def _on_lazy_load_done(self, key, task):
        """Forget a finished lazy-load task and log its failure, if any."""
        self._pending_loads.pop(key, None)
        if not task.cancelled() and task.exception() is not None:
            logger.error(f"Lazy loading of translation model for {key} failed: {task.exception()}")

    def get_model(self, src_lang, tgt_lang):
        """Return the loaded manager for the language pair.

        With lazy loading enabled and the model not loaded yet, a background
        load is scheduled (once per key) and ``None`` is returned; callers
        must handle that case. With lazy loading disabled a missing model
        raises ``ValueError``.
        """
        key = self._get_model_key(src_lang, tgt_lang)
        if key not in self.models:
            if not self.is_lazy_loading:
                raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
            if key not in self._pending_loads:
                task = asyncio.create_task(self.load_model(src_lang, tgt_lang, key))
                self._pending_loads[key] = task
                task.add_done_callback(lambda done, k=key: self._on_lazy_load_done(k, done))
        return self.models.get(key)

    def _get_model_key(self, src_lang, tgt_lang):
        """Map a language pair to its direction key.

        Raises:
            ValueError: when both codes start with ``"eng"``.
        """
        src_is_english = src_lang.startswith("eng")
        tgt_is_english = tgt_lang.startswith("eng")
        if src_is_english and tgt_is_english:
            raise ValueError("Invalid language combination")
        if src_is_english:
            return 'eng_indic'
        if tgt_is_english:
            return 'indic_eng'
        return 'indic_indic'
420
-
421
- # ASR Manager
422
class ASRModelManager:
    """Holds the speech-recognition model and the language-name -> code map."""

    def __init__(
        self,
        device_type="cuda" if torch.cuda.is_available() else "cpu",
        model_name="ai4bharat/indic-conformer-600m-multilingual",
    ):
        # The default falls back to CPU so that load() does not fail on
        # ``.to("cuda")`` when no GPU is present.
        self.device_type = device_type
        self.model_name = model_name
        self.model = None
        self.model_language = {"kannada": "kn"}

    async def load(self):
        """Load the model once; later calls are no-ops while it is loaded."""
        if self.model is not None:
            return
        logger.info("Loading ASR model asynchronously...")
        # from_pretrained blocks, so it runs in a worker thread.
        model = await asyncio.to_thread(
            AutoModel.from_pretrained,
            self.model_name,
            trust_remote_code=True
        )
        self.model = model.to(self.device_type)
        logger.info("ASR model loaded asynchronously")
438
-
439
- # Global Managers
440
# Module-level singletons shared by the request handlers below.
llm_manager = LLMManager(settings.llm_model_name)
# Default arguments: distilled models, lazy loading disabled.
model_manager = ModelManager()
asr_manager = ASRModelManager()
tts_manager = TTSManager()
# Used by the /translate handler to pre- and post-process sentence batches.
ip = IndicProcessor(inference=True)
445
-
446
- # Pydantic Models
447
# Request body for the chat endpoints: a prompt plus source/target language
# codes, both defaulting to Kannada.
# (Documented with comments rather than a docstring so the generated
# OpenAPI schema description is unchanged.)
class ChatRequest(BaseModel):
    prompt: str
    src_lang: str = "kan_Knda"
    tgt_lang: str = "kan_Knda"

    # Reject prompts longer than 1000 characters; return the stripped prompt.
    @field_validator("prompt")
    @classmethod
    def prompt_must_be_valid(cls, v):
        if len(v) > 1000:
            raise ValueError("Prompt cannot exceed 1000 characters")
        return v.strip()

    # Both language codes must be members of SUPPORTED_LANGUAGES.
    @field_validator("src_lang", "tgt_lang")
    @classmethod
    def validate_language(cls, v):
        if v not in SUPPORTED_LANGUAGES:
            raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
        return v
463
-
464
# Response body of the chat endpoints: the final (possibly translated) reply.
class ChatResponse(BaseModel):
    response: str
466
-
467
# Request body for the translation endpoints: a batch of sentences and the
# source/target language codes.
class TranslationRequest(BaseModel):
    sentences: List[str]
    src_lang: str
    tgt_lang: str
471
-
472
# Response body of /transcribe/: the recognised text.
class TranscriptionResponse(BaseModel):
    text: str
474
-
475
# Response body of the translation endpoints: one translation per sentence.
class TranslationResponse(BaseModel):
    translations: List[str]
477
-
478
- # Dependency
479
def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
    """FastAPI dependency: look up the translation manager for a language pair."""
    manager = model_manager.get_model(src_lang, tgt_lang)
    return manager
481
-
482
- # Lifespan Event Handler
483
# Extra translation pairs to preload at startup; extended by the __main__
# block from the selected configuration and read by lifespan().
translation_configs = []
484
-
485
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model before serving and unload the LLM on shutdown."""

    async def load_all_models():
        try:
            # One load per direction key. ModelManager stores a single model
            # per key, so scheduling two loads for the same key only wastes
            # time and memory. Entries from translation_configs replace the
            # default pair for their key.
            translation_pairs = {
                'eng_indic': ('eng_Latn', 'kan_Knda'),
                'indic_eng': ('kan_Knda', 'eng_Latn'),
                'indic_indic': ('kan_Knda', 'hin_Deva'),
            }
            for config in translation_configs:
                src_lang = config["src_lang"]
                tgt_lang = config["tgt_lang"]
                key = model_manager._get_model_key(src_lang, tgt_lang)
                translation_pairs[key] = (src_lang, tgt_lang)

            await asyncio.gather(
                llm_manager.load(),
                tts_manager.load(),
                asr_manager.load(),
                *(
                    model_manager.load_model(src_lang, tgt_lang, key)
                    for key, (src_lang, tgt_lang) in translation_pairs.items()
                ),
            )
            logger.info("All models loaded successfully asynchronously")
        except Exception as e:
            logger.error(f"Error loading models: {str(e)}")
            raise

    logger.info("Starting asynchronous model loading...")
    await load_all_models()
    try:
        yield
    finally:
        # Runs even when the application exits with an error.
        llm_manager.unload()
        logger.info("Server shutdown complete")
518
-
519
- # FastAPI App
520
# Application object; lifespan() above handles model loading and unloading.
app = FastAPI(
    title="Dhwani API",
    description="AI Chat API supporting Indian languages",
    version="1.0.0",
    redirect_slashes=False,
    lifespan=lifespan
)

# CORS: any origin, method and header is accepted; credentials are disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Rate limiter keyed by the client's remote address; used by the
# @limiter.limit decorators on the chat endpoints.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
538
-
539
- # API Endpoints
540
# Synthesize Kannada speech for the request text using the "KAN_F (Happy)"
# reference voice and return it as a WAV attachment.
# Responds 503 when the TTS model is not loaded, 400 for empty text and 500
# when the reference example is missing from EXAMPLES.
@app.post("/audio/speech", response_class=StreamingResponse)
async def synthesize_kannada(request: KannadaSynthesizeRequest):
    if not tts_manager.model:
        raise HTTPException(status_code=503, detail="TTS model not loaded")
    # Validate the input before doing any lookup work.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")

    ref_audio_name = "KAN_F (Happy)"
    # A default avoids an unhandled StopIteration if the example is removed.
    kannada_example = next((ex for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
    if kannada_example is None:
        raise HTTPException(status_code=500, detail="Reference audio example is not configured.")

    audio_buffer = synthesize_speech(
        tts_manager,
        text=request.text,
        ref_audio_name=ref_audio_name,
        ref_text=kannada_example["ref_text"]
    )

    return StreamingResponse(
        audio_buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
    )
560
-
561
# Translate a batch of sentences with the manager supplied by the
# get_translate_manager dependency.
# NOTE(review): the dependency takes src_lang/tgt_lang as its own parameters,
# while pre/post-processing uses the values in the request body; confirm the
# two always agree.
@app.post("/translate", response_model=TranslationResponse)
async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
    input_sentences = request.sentences
    src_lang = request.src_lang
    tgt_lang = request.tgt_lang

    if not input_sentences:
        raise HTTPException(status_code=400, detail="Input sentences are required")

    # get_model() returns None while a lazy load is still in progress; fail
    # with a clear 503 instead of an AttributeError further down.
    if translate_manager is None or getattr(translate_manager, "model", None) is None:
        raise HTTPException(status_code=503, detail="Translation model not loaded")

    batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = translate_manager.tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(translate_manager.device_type)

    with torch.no_grad():
        generated_tokens = translate_manager.model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    with translate_manager.tokenizer.as_target_tokenizer():
        generated_tokens = translate_manager.tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
    return TranslationResponse(translations=translations)
598
-
599
async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    """Translate ``sentences`` in-process, loading the model first if needed.

    Handles both ways ``model_manager.get_model`` signals a missing model:
    a ``ValueError`` (lazy loading disabled) and a ``None`` return (lazy
    load scheduled but not finished).
    """
    try:
        translate_manager = model_manager.get_model(src_lang, tgt_lang)
    except ValueError as e:
        logger.info(f"Model not preloaded: {str(e)}, loading now...")
        translate_manager = None

    if translate_manager is None:
        key = model_manager._get_model_key(src_lang, tgt_lang)
        await model_manager.load_model(src_lang, tgt_lang, key)
        translate_manager = model_manager.get_model(src_lang, tgt_lang)

    if not translate_manager.model:
        await translate_manager.load()

    request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    response = await translate(request, translate_manager)
    return response.translations
614
-
615
# Liveness probe: reports a fixed status plus the configured LLM name.
@app.get("/v1/health")
async def health_check():
    payload = {"status": "healthy", "model": settings.llm_model_name}
    return payload
618
-
619
# Root path: redirect to the interactive API documentation.
@app.get("/")
async def home():
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
622
-
623
# Unload models on demand.
# NOTE(review): only llm_manager is unloaded here, although the name and
# the response message say "all models" - confirm this is intended.
@app.post("/v1/unload_all_models")
async def unload_all_models():
    try:
        logger.info("Starting to unload all models...")
        llm_manager.unload()
        logger.info("All models unloaded successfully")
        result = {"status": "success", "message": "All models unloaded"}
        return result
    except Exception as e:
        failure = str(e)
        logger.error(f"Error unloading models: {failure}")
        raise HTTPException(status_code=500, detail=f"Failed to unload models: {failure}")
633
-
634
# Load models on demand.
# NOTE(review): only llm_manager is loaded here, although the name and the
# response message say "all models" - confirm this is intended.
@app.post("/v1/load_all_models")
async def load_all_models():
    try:
        logger.info("Starting to load all models...")
        await llm_manager.load()
        logger.info("All models loaded successfully")
        result = {"status": "success", "message": "All models loaded"}
        return result
    except Exception as e:
        failure = str(e)
        logger.error(f"Error loading models: {failure}")
        raise HTTPException(status_code=500, detail=f"Failed to load models: {failure}")
644
-
645
# Versioned translation endpoint; delegates to perform_internal_translation.
# HTTP errors raised further down keep their own status code; anything else
# is reported as a 500.
@app.post("/v1/translate", response_model=TranslationResponse)
async def translate_endpoint(request: TranslationRequest):
    # model_dump() is the pydantic v2 replacement for the deprecated dict().
    logger.info(f"Received translation request: {request.model_dump()}")
    try:
        translations = await perform_internal_translation(
            sentences=request.sentences,
            src_lang=request.src_lang,
            tgt_lang=request.tgt_lang
        )
        logger.info(f"Translation successful: {translations}")
        return TranslationResponse(translations=translations)
    except HTTPException:
        # e.g. the 400 for an empty sentence list must not become a 500.
        raise
    except Exception as e:
        logger.error(f"Unexpected error during translation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
659
-
660
# Text chat: translate the prompt to English when required, ask the LLM,
# then translate the answer into the target language when required.
# English and the listed European languages bypass translation.
@app.post("/v1/chat", response_model=ChatResponse)
@limiter.limit(settings.chat_rate_limit)
async def chat(request: Request, chat_request: ChatRequest):
    if not chat_request.prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")

    EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}

    def needs_translation(lang: str) -> bool:
        return lang != "eng_Latn" and lang not in EUROPEAN_LANGUAGES

    try:
        if needs_translation(chat_request.src_lang):
            translated_prompt = await perform_internal_translation(
                sentences=[chat_request.prompt],
                src_lang=chat_request.src_lang,
                tgt_lang="eng_Latn"
            )
            prompt_to_process = translated_prompt[0]
            logger.info(f"Translated prompt to English: {prompt_to_process}")
        else:
            prompt_to_process = chat_request.prompt
            logger.info("Prompt in English or European language, no translation needed")

        response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
        logger.info(f"Generated response: {response}")

        if needs_translation(chat_request.tgt_lang):
            translated_response = await perform_internal_translation(
                sentences=[response],
                src_lang="eng_Latn",
                tgt_lang=chat_request.tgt_lang
            )
            final_response = translated_response[0]
            logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
        else:
            final_response = response
            logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")

        return ChatResponse(response=final_response)
    except HTTPException:
        # Keep the status and detail chosen by the code that raised it
        # instead of rewrapping it as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
701
-
702
# Answer a question about an uploaded image, translating the query to
# English and the answer to tgt_lang when they are not English.
# Invalid images produce a 400; unexpected failures produce a 500.
@app.post("/v1/visual_query/")
async def visual_query(
    file: UploadFile = File(...),
    query: str = Body(...),
    src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
    tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
):
    try:
        try:
            image = Image.open(file.file)
        except (OSError, ValueError) as e:
            # PIL reports unreadable files through OSError subclasses; that
            # is a client error, not a server fault.
            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid") from e
        if image.size == (0, 0):
            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")

        if src_lang != "eng_Latn":
            translated_query = await perform_internal_translation(
                sentences=[query],
                src_lang=src_lang,
                tgt_lang="eng_Latn"
            )
            query_to_process = translated_query[0]
            logger.info(f"Translated query to English: {query_to_process}")
        else:
            query_to_process = query
            logger.info("Query already in English, no translation needed")

        answer = await llm_manager.vision_query(image, query_to_process)
        logger.info(f"Generated English answer: {answer}")

        if tgt_lang != "eng_Latn":
            translated_answer = await perform_internal_translation(
                sentences=[answer],
                src_lang="eng_Latn",
                tgt_lang=tgt_lang
            )
            final_answer = translated_answer[0]
            logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
        else:
            final_answer = answer
            logger.info("Answer kept in English, no translation needed")

        return {"answer": final_answer}
    except HTTPException:
        # Without this, the 400 raised above was caught by the generic
        # handler below and returned to the client as a 500.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
745
-
746
# Chat with an optional image. The prompt is translated to English when
# src_lang is not English. The LLM is then queried through chat_v2 when an
# image was uploaded, or generate when not. The reply is translated to
# tgt_lang when that is not English.
@app.post("/v1/chat_v2", response_model=ChatResponse)
@limiter.limit(settings.chat_rate_limit)
async def chat_v2(
    request: Request,
    prompt: str = Form(...),
    image: UploadFile = File(default=None),
    src_lang: str = Form("kan_Knda"),
    tgt_lang: str = Form("kan_Knda"),
):
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
        raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")

    logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")

    try:
        # Validate the upload first, as before, so a bad image fails fast.
        img = None
        if image:
            image_data = await image.read()
            if not image_data:
                raise HTTPException(status_code=400, detail="Uploaded image is empty")
            img = Image.open(io.BytesIO(image_data))

        if src_lang != "eng_Latn":
            translated_prompt = await perform_internal_translation(
                sentences=[prompt],
                src_lang=src_lang,
                tgt_lang="eng_Latn"
            )
            prompt_to_process = translated_prompt[0]
            logger.info(f"Translated prompt to English: {prompt_to_process}")
        else:
            prompt_to_process = prompt
            logger.info("Prompt already in English, no translation needed")

        # The two original branches differed only in this call.
        if img is not None:
            decoded = await llm_manager.chat_v2(img, prompt_to_process)
        else:
            decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
        logger.info(f"Generated English response: {decoded}")

        if tgt_lang != "eng_Latn":
            translated_response = await perform_internal_translation(
                sentences=[decoded],
                src_lang="eng_Latn",
                tgt_lang=tgt_lang
            )
            final_response = translated_response[0]
            logger.info(f"Translated response to {tgt_lang}: {final_response}")
        else:
            final_response = decoded
            logger.info("Response kept in English, no translation needed")

        return ChatResponse(response=final_response)
    except HTTPException:
        # The empty-image 400 above must reach the client as a 400.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
827
-
828
# Transcribe an uploaded audio file. The channels are averaged, the signal
# is resampled to 16 kHz when needed, and the ASR model is called with the
# language code and the "rnnt" argument.
@app.post("/transcribe/", response_model=TranscriptionResponse)
async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
    if not asr_manager.model:
        raise HTTPException(status_code=503, detail="ASR model not loaded")
    try:
        waveform, source_rate = torchaudio.load(file.file)
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        target_sample_rate = 16000
        if source_rate != target_sample_rate:
            to_target_rate = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=target_sample_rate)
            waveform = to_target_rate(waveform)
        recognised = asr_manager.model(waveform, asr_manager.model_language[language], "rnnt")
        return TranscriptionResponse(text=recognised)
    except Exception as e:
        reason = str(e)
        logger.error(f"Error in transcription: {reason}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {reason}")
844
-
845
# Speech in, speech out: transcribe the upload, run the text through the
# chat endpoint in the same language, then synthesize the reply.
@app.post("/v1/speech_to_speech")
async def speech_to_speech(
    request: Request,
    file: UploadFile = File(...),
    language: str = Query(..., enum=list(asr_manager.model_language.keys())),
) -> StreamingResponse:
    if not tts_manager.model:
        raise HTTPException(status_code=503, detail="TTS model not loaded")

    transcription = await transcribe_audio(file, language)
    logger.info(f"Transcribed text: {transcription.text}")

    # Unknown language names fall back to Kannada on both sides.
    chat_request = ChatRequest(
        prompt=transcription.text,
        src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
        tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
    )
    chat_result = await chat(request, chat_request)
    logger.info(f"Processed text: {chat_result.response}")

    return await synthesize_kannada(KannadaSynthesizeRequest(text=chat_result.response))
867
-
868
# Maps the language names accepted by the speech endpoints to the language
# codes used by the chat/translation code. Looked up by speech_to_speech.
LANGUAGE_TO_SCRIPT = {
    "kannada": "kan_Knda"
}
871
-
872
- # Main Execution
873
if __name__ == "__main__":
    # Command-line entry point: pick a named configuration from
    # dhwani_config.json, apply it to the settings and managers, then start
    # uvicorn.
    parser = argparse.ArgumentParser(description="Run the FastAPI server.")
    # Defaults are None so an omitted flag can be told apart from an explicit
    # one. Previously the pre-config settings value was the default, so the
    # host/port from the config file could never take effect.
    parser.add_argument("--port", type=int, default=None, help="Port to run the server on.")
    parser.add_argument("--host", type=str, default=None, help="Host to run the server on.")
    parser.add_argument("--config", type=str, default="config_one", help="Configuration to use")
    args = parser.parse_args()

    def load_config(config_path="dhwani_config.json"):
        """Read and parse the JSON configuration file."""
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)

    config_data = load_config()
    if args.config not in config_data["configs"]:
        raise ValueError(f"Invalid config: {args.config}. Available: {list(config_data['configs'].keys())}")

    selected_config = config_data["configs"][args.config]
    global_settings = config_data["global_settings"]

    settings.llm_model_name = selected_config["components"]["LLM"]["model"]
    settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
    settings.host = global_settings["host"]
    settings.port = global_settings["port"]
    # NOTE: the @limiter.limit(...) decorators read settings.chat_rate_limit
    # when the endpoints were defined, so changing it here does not alter
    # the limits already applied.
    settings.chat_rate_limit = global_settings["chat_rate_limit"]
    settings.speech_rate_limit = global_settings["speech_rate_limit"]

    # Rebuild the LLM manager so it uses the configured model name.
    llm_manager = LLMManager(settings.llm_model_name)

    if selected_config["components"]["ASR"]:
        # NOTE(review): asr_model_name is read but not used anywhere in this
        # block; ASRModelManager loads its own checkpoint.
        asr_model_name = selected_config["components"]["ASR"]["model"]
        asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]

    if selected_config["components"]["Translation"]:
        translation_configs.extend(selected_config["components"]["Translation"])

    # Explicit command-line values win; otherwise use the configured ones.
    host = args.host if args.host is not None else settings.host
    port = args.port if args.port is not None else settings.port

    uvicorn.run(app, host=host, port=port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/server/main_local.py DELETED
@@ -1,913 +0,0 @@
1
- import argparse
2
- import io
3
- import os
4
- from time import time
5
- from typing import List
6
- import tempfile
7
- import uvicorn
8
- from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
9
- from fastapi.middleware.cors import CORSMiddleware
10
- from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
- from PIL import Image
12
- from pydantic import BaseModel, field_validator
13
- from pydantic_settings import BaseSettings
14
- from slowapi import Limiter
15
- from slowapi.util import get_remote_address
16
- import torch
17
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, AutoModel, BitsAndBytesConfig, Gemma3ForConditionalGeneration
18
- from IndicTransToolkit import IndicProcessor
19
- import json
20
- import asyncio
21
- from contextlib import asynccontextmanager
22
- import soundfile as sf
23
- import numpy as np
24
- import requests
25
- from starlette.responses import StreamingResponse
26
- from logging_config import logger
27
- from tts_config import SPEED, ResponseFormat, config as tts_config
28
- import torchaudio
29
-
30
- # Device setup
31
# Pick the inference device once at import time: the first CUDA device when
# one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
    logger.info("GPU will be used for inference")
else:
    device = "cpu"
    logger.info("CPU will be used for inference")
# bfloat16 on GPU, float32 on CPU.
torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32

# Check CUDA availability and version
cuda_available = torch.cuda.is_available()
cuda_version = torch.version.cuda if cuda_available else None

# Print the CUDA version and the compute capability ("major.minor") of the
# current device for diagnostics.
if torch.cuda.is_available():
    device_idx = torch.cuda.current_device()
    capability = torch.cuda.get_device_capability(device_idx)
    compute_capability_float = float(f"{capability[0]}.{capability[1]}")
    print(f"CUDA version: {cuda_version}")
    print(f"CUDA Compute Capability: {compute_capability_float}")
else:
    print("CUDA is not available on this system.")
51
-
52
- # Settings
53
# Server settings with their defaults; values may also come from a .env file.
class Settings(BaseSettings):
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"

    # A rate limit must look like "<digits>/<period>" with exactly one slash
    # and a non-empty period, e.g. "5/minute".
    @field_validator("chat_rate_limit", "speech_rate_limit")
    @classmethod
    def validate_rate_limit(cls, v):
        count, slash, period = v.partition("/")
        if not slash or "/" in period or not count.isdigit() or not period.strip():
            raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
        return v

    class Config:
        env_file = ".env"
69
-
70
# Single settings instance used throughout the module.
settings = Settings()
71
-
72
- # Quantization config for LLM
73
# 4-bit NF4 quantization with double quantization and bfloat16 compute;
# passed to the LLM's from_pretrained call in LLMManager.load.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
79
-
80
- # LLM Manager
81
class LLMManager:
    """Owns the Gemma 3 model and processor and runs chat-style generation.

    The model is loaded lazily: every public generation method calls
    ``load()`` first when nothing is loaded yet.
    """

    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.model_name = model_name
        self.device = torch.device(device)
        self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
        self.model = None
        self.processor = None
        self.is_loaded = False
        logger.info(f"LLMManager initialized with model {model_name} on {self.device}")

    async def load(self):
        """Load model and processor from the local model directory (once)."""
        if not self.is_loaded:
            try:
                local_path = "/app/models/llm_model"
                # from_pretrained blocks, so it runs in a worker thread.
                self.model = await asyncio.to_thread(
                    Gemma3ForConditionalGeneration.from_pretrained,
                    local_path,
                    device_map="auto",
                    quantization_config=quantization_config,
                    torch_dtype=self.torch_dtype
                )
                self.model.eval()
                self.processor = await asyncio.to_thread(
                    AutoProcessor.from_pretrained,
                    local_path
                )
                self.is_loaded = True
                logger.info(f"LLM loaded from {local_path} on {self.device}")
            except Exception as e:
                logger.error(f"Failed to load LLM: {str(e)}")
                raise

    def unload(self):
        """Release the model and processor and free cached GPU memory."""
        if self.is_loaded:
            # Rebind to None instead of ``del`` so the attributes keep
            # existing and later reads cannot raise AttributeError.
            self.model = None
            self.processor = None
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
                logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
            self.is_loaded = False
            logger.info(f"LLM {self.model_name} unloaded from {self.device}")

    def _generate_from_messages(self, messages, max_new_tokens, temperature, error_log_label, error_detail_label):
        """Template ``messages``, sample a completion and decode the new tokens.

        Raises:
            HTTPException: status 500 when templating/tokenization fails.
        """
        try:
            # Use the dtype chosen for this device rather than a fixed
            # bfloat16 so CPU inputs match the float32 model.
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            ).to(self.device, dtype=self.torch_dtype)
        except Exception as e:
            logger.error(f"{error_log_label}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"{error_detail_label}: {str(e)}")

        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature
            )
            # Keep only the tokens produced after the prompt.
            generation = generation[0][input_len:]

        return self.processor.decode(generation, skip_special_tokens=True)

    @staticmethod
    def _build_image_messages(system_prompt, image, query):
        """Build system + user messages, putting a usable image before the text."""
        user_content = [{"type": "text", "text": query}]
        if image and image.size[0] > 0 and image.size[1] > 0:
            user_content.insert(0, {"type": "image", "image": image})
            logger.info(f"Received valid image for processing")
        else:
            logger.info("No valid image provided, processing text only")
        return [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": user_content},
        ]

    async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
        """Answer a text-only prompt with the one-sentence system instruction."""
        if not self.is_loaded:
            await self.load()

        messages_vlm = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}]
            }
        ]

        response = self._generate_from_messages(
            messages_vlm, max_tokens, temperature,
            "Error in tokenization", "Tokenization failed",
        )
        logger.info(f"Generated response: {response}")
        return response

    async def vision_query(self, image: Image.Image, query: str) -> str:
        """Answer ``query`` about ``image`` with a one-sentence summary."""
        if not self.is_loaded:
            await self.load()

        messages_vlm = self._build_image_messages(
            "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence.",
            image,
            query,
        )
        decoded = self._generate_from_messages(
            messages_vlm, 512, 0.7,
            "Error in apply_chat_template", "Failed to process input",
        )
        logger.info(f"Vision query response: {decoded}")
        return decoded

    async def chat_v2(self, image: Image.Image, query: str) -> str:
        """Answer ``query`` with an optional image, without the length limit."""
        if not self.is_loaded:
            await self.load()

        messages_vlm = self._build_image_messages(
            "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state.",
            image,
            query,
        )
        decoded = self._generate_from_messages(
            messages_vlm, 512, 0.7,
            "Error in apply_chat_template", "Failed to process input",
        )
        logger.info(f"Chat_v2 response: {decoded}")
        return decoded
262
-
263
- # TTS Manager
264
class TTSManager:
    """Wraps the text-to-speech model stored in the local model directory."""

    def __init__(self, device_type=device):
        self.device_type = device_type
        self.model = None
        self.repo_id = "ai4bharat/IndicF5"

    async def load(self):
        """Load the model in a worker thread; do nothing if already loaded."""
        if self.model:
            return
        logger.info("Loading TTS model from local path asynchronously...")
        local_path = "/app/models/tts_model"
        loaded = await asyncio.to_thread(
            AutoModel.from_pretrained,
            local_path,
            trust_remote_code=True
        )
        self.model = loaded
        self.model = self.model.to(self.device_type)
        logger.info("TTS model loaded from local path asynchronously")

    def synthesize(self, text, ref_audio_path, ref_text):
        """Call the model with the text and reference audio path/text.

        Raises:
            ValueError: when the model has not been loaded.
        """
        if self.model:
            return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
        raise ValueError("TTS model not loaded")
286
-
287
- # TTS Constants
288
# Reference voices available for synthesis. Each entry has a display name,
# the URL of the reference recording, the text of that recording
# ("ref_text") and a sample sentence ("synth_text").
EXAMPLES = [
    {
        "audio_name": "KAN_F (Happy)",
        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
    },
]
296
-
297
- # Pydantic models for TTS
298
- class SynthesizeRequest(BaseModel):
299
- text: str
300
- ref_audio_name: str
301
- ref_text: str = None
302
-
303
- class KannadaSynthesizeRequest(BaseModel):
304
- text: str
305
-
306
- # TTS Functions
307
- def load_audio_from_url(url: str):
308
- response = requests.get(url)
309
- if response.status_code == 200:
310
- audio_data, sample_rate = sf.read(io.BytesIO(response.content))
311
- return sample_rate, audio_data
312
- raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
313
-
314
- def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
315
- ref_audio_url = None
316
- for example in EXAMPLES:
317
- if example["audio_name"] == ref_audio_name:
318
- ref_audio_url = example["audio_url"]
319
- if not ref_text:
320
- ref_text = example["ref_text"]
321
- break
322
-
323
- if not ref_audio_url:
324
- raise HTTPException(status_code=400, detail="Invalid reference audio name.")
325
- if not text.strip():
326
- raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
327
- if not ref_text or not ref_text.strip():
328
- raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
329
-
330
- sample_rate, audio_data = load_audio_from_url(ref_audio_url)
331
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
332
- sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
333
- temp_audio.flush()
334
- audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
335
-
336
- if audio.dtype == np.int16:
337
- audio = audio.astype(np.float32) / 32768.0
338
- buffer = io.BytesIO()
339
- sf.write(buffer, audio, 24000, format='WAV')
340
- buffer.seek(0)
341
- return buffer
342
-
343
- # Supported languages
344
- SUPPORTED_LANGUAGES = {
345
- "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
346
- "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
347
- "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
348
- "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
349
- "kan_Knda", "ory_Orya",
350
- "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
351
- "por_Latn", "rus_Cyrl", "pol_Latn"
352
- }
353
-
354
- # Translation Manager
355
- class TranslateManager:
356
- def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
357
- self.device_type = device_type
358
- self.tokenizer = None
359
- self.model = None
360
- self.src_lang = src_lang
361
- self.tgt_lang = tgt_lang
362
- self.use_distilled = use_distilled
363
-
364
- async def load(self):
365
- if not self.tokenizer or not self.model:
366
- if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
367
- local_path = "/app/models/trans_en_indic"
368
- elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
369
- local_path = "/app/models/trans_indic_en"
370
- elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
371
- local_path = "/app/models/trans_indic_indic"
372
- else:
373
- raise ValueError("Invalid language combination")
374
-
375
- self.tokenizer = await asyncio.to_thread(
376
- AutoTokenizer.from_pretrained,
377
- local_path,
378
- trust_remote_code=True
379
- )
380
- self.model = await asyncio.to_thread(
381
- AutoModelForSeq2SeqLM.from_pretrained,
382
- local_path,
383
- trust_remote_code=True,
384
- torch_dtype=torch.float16,
385
- attn_implementation="flash_attention_2"
386
- )
387
- self.model = self.model.to(self.device_type)
388
- self.model = torch.compile(self.model, mode="reduce-overhead")
389
- logger.info(f"Translation model loaded from {local_path} asynchronously")
390
-
391
- class ModelManager:
392
- def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
393
- self.models = {}
394
- self.device_type = device_type
395
- self.use_distilled = use_distilled
396
- self.is_lazy_loading = is_lazy_loading
397
-
398
- async def load_model(self, src_lang, tgt_lang, key):
399
- logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} from local path")
400
- translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
401
- await translate_manager.load()
402
- self.models[key] = translate_manager
403
- logger.info(f"Loaded translation model for {key} from local path")
404
-
405
- def get_model(self, src_lang, tgt_lang):
406
- key = self._get_model_key(src_lang, tgt_lang)
407
- if key not in self.models:
408
- if self.is_lazy_loading:
409
- asyncio.create_task(self.load_model(src_lang, tgt_lang, key))
410
- else:
411
- raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
412
- return self.models.get(key)
413
-
414
- def _get_model_key(self, src_lang, tgt_lang):
415
- if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
416
- return 'eng_indic'
417
- elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
418
- return 'indic_eng'
419
- elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
420
- return 'indic_indic'
421
- raise ValueError("Invalid language combination")
422
-
423
- # ASR Manager
424
- class ASRModelManager:
425
- def __init__(self, device_type="cuda"):
426
- self.device_type = device_type
427
- self.model = None
428
- self.model_language = {"kannada": "kn"}
429
-
430
- async def load(self):
431
- if not self.model:
432
- logger.info("Loading ASR model from local path asynchronously...")
433
- local_path = "/app/models/asr_model"
434
- self.model = await asyncio.to_thread(
435
- AutoModel.from_pretrained,
436
- local_path,
437
- trust_remote_code=True
438
- )
439
- self.model = self.model.to(self.device_type)
440
- logger.info("ASR model loaded from local path asynchronously")
441
-
442
- # Global Managers
443
- llm_manager = LLMManager(settings.llm_model_name)
444
- model_manager = ModelManager()
445
- asr_manager = ASRModelManager()
446
- tts_manager = TTSManager()
447
- ip = IndicProcessor(inference=True)
448
-
449
- # Pydantic Models
450
- class ChatRequest(BaseModel):
451
- prompt: str
452
- src_lang: str = "kan_Knda"
453
- tgt_lang: str = "kan_Knda"
454
-
455
- @field_validator("prompt")
456
- def prompt_must_be_valid(cls, v):
457
- if len(v) > 1000:
458
- raise ValueError("Prompt cannot exceed 1000 characters")
459
- return v.strip()
460
-
461
- @field_validator("src_lang", "tgt_lang")
462
- def validate_language(cls, v):
463
- if v not in SUPPORTED_LANGUAGES:
464
- raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
465
- return v
466
-
467
- class ChatResponse(BaseModel):
468
- response: str
469
-
470
- class TranslationRequest(BaseModel):
471
- sentences: List[str]
472
- src_lang: str
473
- tgt_lang: str
474
-
475
- class TranscriptionResponse(BaseModel):
476
- text: str
477
-
478
- class TranslationResponse(BaseModel):
479
- translations: List[str]
480
-
481
- # Dependency
482
- def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
483
- return model_manager.get_model(src_lang, tgt_lang)
484
-
485
- # Lifespan Event Handler
486
- translation_configs = []
487
-
488
- @asynccontextmanager
489
- async def lifespan(app: FastAPI):
490
- async def load_all_models():
491
- try:
492
- tasks = [
493
- llm_manager.load(),
494
- tts_manager.load(),
495
- asr_manager.load(),
496
- ]
497
-
498
- translation_tasks = [
499
- model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic'),
500
- model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng'),
501
- model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic'),
502
- ]
503
-
504
- for config in translation_configs:
505
- src_lang = config["src_lang"]
506
- tgt_lang = config["tgt_lang"]
507
- key = model_manager._get_model_key(src_lang, tgt_lang)
508
- translation_tasks.append(model_manager.load_model(src_lang, tgt_lang, key))
509
-
510
- await asyncio.gather(*tasks, *translation_tasks)
511
- logger.info("All models loaded successfully from local paths")
512
- except Exception as e:
513
- logger.error(f"Error loading models: {str(e)}")
514
- raise
515
-
516
- logger.info("Starting asynchronous model loading from local paths...")
517
- await load_all_models()
518
- yield
519
- llm_manager.unload()
520
- logger.info("Server shutdown complete")
521
-
522
- # FastAPI App
523
- app = FastAPI(
524
- title="Dhwani API",
525
- description="AI Chat API supporting Indian languages",
526
- version="1.0.0",
527
- redirect_slashes=False,
528
- lifespan=lifespan
529
- )
530
-
531
- app.add_middleware(
532
- CORSMiddleware,
533
- allow_origins=["*"],
534
- allow_credentials=False,
535
- allow_methods=["*"],
536
- allow_headers=["*"],
537
- )
538
-
539
- limiter = Limiter(key_func=get_remote_address)
540
- app.state.limiter = limiter
541
-
542
- # API Endpoints
543
- @app.post("/audio/speech", response_class=StreamingResponse)
544
- async def synthesize_kannada(request: KannadaSynthesizeRequest):
545
- if not tts_manager.model:
546
- raise HTTPException(status_code=503, detail="TTS model not loaded")
547
- kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
548
- if not request.text.strip():
549
- raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
550
-
551
- audio_buffer = synthesize_speech(
552
- tts_manager,
553
- text=request.text,
554
- ref_audio_name="KAN_F (Happy)",
555
- ref_text=kannada_example["ref_text"]
556
- )
557
-
558
- return StreamingResponse(
559
- audio_buffer,
560
- media_type="audio/wav",
561
- headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
562
- )
563
-
564
- @app.post("/translate", response_model=TranslationResponse)
565
- async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
566
- input_sentences = request.sentences
567
- src_lang = request.src_lang
568
- tgt_lang = request.tgt_lang
569
-
570
- if not input_sentences:
571
- raise HTTPException(status_code=400, detail="Input sentences are required")
572
-
573
- batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
574
- inputs = translate_manager.tokenizer(
575
- batch,
576
- truncation=True,
577
- padding="longest",
578
- return_tensors="pt",
579
- return_attention_mask=True,
580
- ).to(translate_manager.device_type)
581
-
582
- with torch.no_grad():
583
- generated_tokens = translate_manager.model.generate(
584
- **inputs,
585
- use_cache=True,
586
- min_length=0,
587
- max_length=256,
588
- num_beams=5,
589
- num_return_sequences=1,
590
- )
591
-
592
- with translate_manager.tokenizer.as_target_tokenizer():
593
- generated_tokens = translate_manager.tokenizer.batch_decode(
594
- generated_tokens.detach().cpu().tolist(),
595
- skip_special_tokens=True,
596
- clean_up_tokenization_spaces=True,
597
- )
598
-
599
- translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
600
- return TranslationResponse(translations=translations)
601
-
602
- async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
603
- try:
604
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
605
- except ValueError as e:
606
- logger.info(f"Model not preloaded: {str(e)}, loading now...")
607
- key = model_manager._get_model_key(src_lang, tgt_lang)
608
- await model_manager.load_model(src_lang, tgt_lang, key)
609
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
610
-
611
- if not translate_manager.model:
612
- await translate_manager.load()
613
-
614
- request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
615
- response = await translate(request, translate_manager)
616
- return response.translations
617
-
618
- @app.get("/v1/health")
619
- async def health_check():
620
- return {"status": "healthy", "model": settings.llm_model_name}
621
-
622
- @app.get("/")
623
- async def home():
624
- return RedirectResponse(url="/docs")
625
-
626
- @app.post("/v1/unload_all_models")
627
- async def unload_all_models():
628
- try:
629
- logger.info("Starting to unload all models...")
630
- llm_manager.unload()
631
- logger.info("All models unloaded successfully")
632
- return {"status": "success", "message": "All models unloaded"}
633
- except Exception as e:
634
- logger.error(f"Error unloading models: {str(e)}")
635
- raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")
636
-
637
- @app.post("/v1/load_all_models")
638
- async def load_all_models():
639
- try:
640
- logger.info("Starting to load all models...")
641
- await llm_manager.load()
642
- logger.info("All models loaded successfully")
643
- return {"status": "success", "message": "All models loaded"}
644
- except Exception as e:
645
- logger.error(f"Error loading models: {str(e)}")
646
- raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")
647
-
648
- @app.post("/v1/translate", response_model=TranslationResponse)
649
- async def translate_endpoint(request: TranslationRequest):
650
- logger.info(f"Received translation request: {request.dict()}")
651
- try:
652
- translations = await perform_internal_translation(
653
- sentences=request.sentences,
654
- src_lang=request.src_lang,
655
- tgt_lang=request.tgt_lang
656
- )
657
- logger.info(f"Translation successful: {translations}")
658
- return TranslationResponse(translations=translations)
659
- except Exception as e:
660
- logger.error(f"Unexpected error during translation: {str(e)}")
661
- raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
662
-
663
- @app.post("/v1/chat", response_model=ChatResponse)
664
- @limiter.limit(settings.chat_rate_limit)
665
- async def chat(request: Request, chat_request: ChatRequest):
666
- if not chat_request.prompt:
667
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
668
- logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
669
-
670
- EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
671
-
672
- try:
673
- if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
674
- translated_prompt = await perform_internal_translation(
675
- sentences=[chat_request.prompt],
676
- src_lang=chat_request.src_lang,
677
- tgt_lang="eng_Latn"
678
- )
679
- prompt_to_process = translated_prompt[0]
680
- logger.info(f"Translated prompt to English: {prompt_to_process}")
681
- else:
682
- prompt_to_process = chat_request.prompt
683
- logger.info("Prompt in English or European language, no translation needed")
684
-
685
- response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
686
- logger.info(f"Generated response: {response}")
687
-
688
- if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
689
- translated_response = await perform_internal_translation(
690
- sentences=[response],
691
- src_lang="eng_Latn",
692
- tgt_lang=chat_request.tgt_lang
693
- )
694
- final_response = translated_response[0]
695
- logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
696
- else:
697
- final_response = response
698
- logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
699
-
700
- return ChatResponse(response=final_response)
701
- except Exception as e:
702
- logger.error(f"Error processing request: {str(e)}")
703
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
704
-
705
- @app.post("/v1/visual_query/")
706
- async def visual_query(
707
- file: UploadFile = File(...),
708
- query: str = Body(...),
709
- src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
710
- tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
711
- ):
712
- try:
713
- image = Image.open(file.file)
714
- if image.size == (0, 0):
715
- raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
716
-
717
- if src_lang != "eng_Latn":
718
- translated_query = await perform_internal_translation(
719
- sentences=[query],
720
- src_lang=src_lang,
721
- tgt_lang="eng_Latn"
722
- )
723
- query_to_process = translated_query[0]
724
- logger.info(f"Translated query to English: {query_to_process}")
725
- else:
726
- query_to_process = query
727
- logger.info("Query already in English, no translation needed")
728
-
729
- answer = await llm_manager.vision_query(image, query_to_process)
730
- logger.info(f"Generated English answer: {answer}")
731
-
732
- if tgt_lang != "eng_Latn":
733
- translated_answer = await perform_internal_translation(
734
- sentences=[answer],
735
- src_lang="eng_Latn",
736
- tgt_lang=tgt_lang
737
- )
738
- final_answer = translated_answer[0]
739
- logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
740
- else:
741
- final_answer = answer
742
- logger.info("Answer kept in English, no translation needed")
743
-
744
- return {"answer": final_answer}
745
- except Exception as e:
746
- logger.error(f"Error processing request: {str(e)}")
747
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
748
-
749
- @app.post("/v1/chat_v2", response_model=ChatResponse)
750
- @limiter.limit(settings.chat_rate_limit)
751
- async def chat_v2(
752
- request: Request,
753
- prompt: str = Form(...),
754
- image: UploadFile = File(default=None),
755
- src_lang: str = Form("kan_Knda"),
756
- tgt_lang: str = Form("kan_Knda"),
757
- ):
758
- if not prompt:
759
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
760
- if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
761
- raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
762
-
763
- logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
764
-
765
- try:
766
- if image:
767
- image_data = await image.read()
768
- if not image_data:
769
- raise HTTPException(status_code=400, detail="Uploaded image is empty")
770
- img = Image.open(io.BytesIO(image_data))
771
-
772
- if src_lang != "eng_Latn":
773
- translated_prompt = await perform_internal_translation(
774
- sentences=[prompt],
775
- src_lang=src_lang,
776
- tgt_lang="eng_Latn"
777
- )
778
- prompt_to_process = translated_prompt[0]
779
- logger.info(f"Translated prompt to English: {prompt_to_process}")
780
- else:
781
- prompt_to_process = prompt
782
- logger.info("Prompt already in English, no translation needed")
783
-
784
- decoded = await llm_manager.chat_v2(img, prompt_to_process)
785
- logger.info(f"Generated English response: {decoded}")
786
-
787
- if tgt_lang != "eng_Latn":
788
- translated_response = await perform_internal_translation(
789
- sentences=[decoded],
790
- src_lang="eng_Latn",
791
- tgt_lang=tgt_lang
792
- )
793
- final_response = translated_response[0]
794
- logger.info(f"Translated response to {tgt_lang}: {final_response}")
795
- else:
796
- final_response = decoded
797
- logger.info("Response kept in English, no translation needed")
798
- else:
799
- if src_lang != "eng_Latn":
800
- translated_prompt = await perform_internal_translation(
801
- sentences=[prompt],
802
- src_lang=src_lang,
803
- tgt_lang="eng_Latn"
804
- )
805
- prompt_to_process = translated_prompt[0]
806
- logger.info(f"Translated prompt to English: {prompt_to_process}")
807
- else:
808
- prompt_to_process = prompt
809
- logger.info("Prompt already in English, no translation needed")
810
-
811
- decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
812
- logger.info(f"Generated English response: {decoded}")
813
-
814
- if tgt_lang != "eng_Latn":
815
- translated_response = await perform_internal_translation(
816
- sentences=[decoded],
817
- src_lang="eng_Latn",
818
- tgt_lang=tgt_lang
819
- )
820
- final_response = translated_response[0]
821
- logger.info(f"Translated response to {tgt_lang}: {final_response}")
822
- else:
823
- final_response = decoded
824
- logger.info("Response kept in English, no translation needed")
825
-
826
- return ChatResponse(response=final_response)
827
- except Exception as e:
828
- logger.error(f"Error processing request: {str(e)}")
829
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
830
-
831
- @app.post("/transcribe/", response_model=TranscriptionResponse)
832
- async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
833
- if not asr_manager.model:
834
- raise HTTPException(status_code=503, detail="ASR model not loaded")
835
- try:
836
- wav, sr = torchaudio.load(file.file)
837
- wav = torch.mean(wav, dim=0, keepdim=True)
838
- target_sample_rate = 16000
839
- if sr != target_sample_rate:
840
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
841
- wav = resampler(wav)
842
- transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
843
- return TranscriptionResponse(text=transcription_rnnt)
844
- except Exception as e:
845
- logger.error(f"Error in transcription: {str(e)}")
846
- raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
847
-
848
- @app.post("/v1/speech_to_speech")
849
- async def speech_to_speech(
850
- request: Request,
851
- file: UploadFile = File(...),
852
- language: str = Query(..., enum=list(asr_manager.model_language.keys())),
853
- ) -> StreamingResponse:
854
- if not tts_manager.model:
855
- raise HTTPException(status_code=503, detail="TTS model not loaded")
856
- transcription = await transcribe_audio(file, language)
857
- logger.info(f"Transcribed text: {transcription.text}")
858
-
859
- chat_request = ChatRequest(
860
- prompt=transcription.text,
861
- src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
862
- tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
863
- )
864
- processed_text = await chat(request, chat_request)
865
- logger.info(f"Processed text: {processed_text.response}")
866
-
867
- voice_request = KannadaSynthesizeRequest(text=processed_text.response)
868
- audio_response = await synthesize_kannada(voice_request)
869
- return audio_response
870
-
871
- LANGUAGE_TO_SCRIPT = {
872
- "kannada": "kan_Knda"
873
- }
874
-
875
- # Main Execution
876
- if __name__ == "__main__":
877
- parser = argparse.ArgumentParser(description="Run the FastAPI server.")
878
- parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
879
- parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
880
- parser.add_argument("--config", type=str, default="config_one", help="Configuration to use")
881
- args = parser.parse_args()
882
-
883
- def load_config(config_path="dhwani_config.json"):
884
- with open(config_path, "r") as f:
885
- return json.load(f)
886
-
887
- config_data = load_config()
888
- if args.config not in config_data["configs"]:
889
- raise ValueError(f"Invalid config: {args.config}. Available: {list(config_data['configs'].keys())}")
890
-
891
- selected_config = config_data["configs"][args.config]
892
- global_settings = config_data["global_settings"]
893
-
894
- settings.llm_model_name = selected_config["components"]["LLM"]["model"]
895
- settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
896
- settings.host = global_settings["host"]
897
- settings.port = global_settings["port"]
898
- settings.chat_rate_limit = global_settings["chat_rate_limit"]
899
- settings.speech_rate_limit = global_settings["speech_rate_limit"]
900
-
901
- llm_manager = LLMManager(settings.llm_model_name)
902
-
903
- if selected_config["components"]["ASR"]:
904
- asr_model_name = selected_config["components"]["ASR"]["model"]
905
- asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
906
-
907
- if selected_config["components"]["Translation"]:
908
- translation_configs.extend(selected_config["components"]["Translation"])
909
-
910
- host = args.host if args.host != settings.host else settings.host
911
- port = args.port if args.port != settings.port else settings.port
912
-
913
- uvicorn.run(app, host=host, port=port)