sachin committed
Commit 8308b9e · 1 Parent(s): 98b17b0

test-dhwani-base-model

Files changed (6)
  1. Dockerfile +1 -41
  2. Dockerfile.app +17 -0
  3. Dockerfile.base +32 -0
  4. Dockerfile.models +16 -0
  5. src/server/main.py +20 -23
  6. src/server/main_local.py +913 -0
Dockerfile CHANGED
@@ -1,49 +1,9 @@
- FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04
+ FROM slabstech/dhwani-server-base
  WORKDIR /app

- # Install system dependencies
- RUN apt-get update && apt-get install -y \
-     python3 \
-     python3-pip python3-distutils python3-dev python3-venv \
-     git \
-     ffmpeg \
-     sudo wget curl software-properties-common build-essential gcc g++ \
-     && ln -s /usr/bin/python3 /usr/bin/python \
-     && rm -rf /var/lib/apt/lists/*
-
- # Install Rust
- RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
- ENV PATH="/root/.cargo/bin:${PATH}"
-
- # Set compiler environment variables
- ENV CC=/usr/bin/gcc
- ENV CXX=/usr/bin/g++
-
- # Upgrade pip and install base Python dependencies
- RUN pip install --upgrade pip setuptools setuptools-rust torch
- RUN pip install flash-attn --no-build-isolation
-
- # Copy requirements and configuration files
- COPY requirements.txt .
  COPY dhwani_config.json .

- # Install Python dependencies
- RUN pip install --no-cache-dir -r requirements.txt
-
  # Create a directory for pre-downloaded models
- RUN mkdir -p /app/models
-
- # Define build argument for HF_TOKEN
- ARG HF_TOKEN_DOCKER
-
- # Set environment variable for the build process
- ENV HF_TOKEN=$HF_TOKEN_DOCKER
-
- # Copy and run the model download script
- COPY download_models.py .
- RUN python download_models.py
-
- # Copy application code
  COPY . .

  # Set up user
Dockerfile.app ADDED
@@ -0,0 +1,17 @@
+ # Use the pre-built image with models as the base
+ FROM slabstech/dhwani-model-server:latest
+ WORKDIR /app
+
+ # Copy application code
+ COPY . .
+
+ # Set up user
+ RUN useradd -ms /bin/bash appuser \
+     && chown -R appuser:appuser /app
+ USER appuser
+
+ # Expose port
+ EXPOSE 7860
+
+ # Start the server
+ CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860", "--config", "config_two"]
Dockerfile.base ADDED
@@ -0,0 +1,32 @@
+ # Base image with CUDA support
+ FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04 AS model-downloader
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     python3 \
+     python3-pip python3-distutils python3-dev python3-venv \
+     git \
+     ffmpeg \
+     sudo wget curl software-properties-common build-essential gcc g++ \
+     && ln -s /usr/bin/python3 /usr/bin/python \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install Rust
+ RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+ ENV PATH="/root/.cargo/bin:${PATH}"
+
+ # Set compiler environment variables
+ ENV CC=/usr/bin/gcc
+ ENV CXX=/usr/bin/g++
+
+ # Upgrade pip and install base Python dependencies
+ RUN pip install --upgrade pip setuptools setuptools-rust torch
+ RUN pip install flash-attn --no-build-isolation
+
+ # Copy requirements and configuration files
+ COPY requirements.txt .
+ COPY dhwani_config.json .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
Dockerfile.models ADDED
@@ -0,0 +1,16 @@
+ # Base image with CUDA support
+ FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04 AS model-downloader
+
+
+ # Create a directory for pre-downloaded models
+ RUN mkdir -p /app/models
+
+ # Define build argument for HF_TOKEN
+ ARG HF_TOKEN_DOCKER
+
+ # Set environment variable for the build process
+ ENV HF_TOKEN=$HF_TOKEN_DOCKER
+
+ # Copy and run the model download script
+ COPY download_models.py .
+ RUN python download_models.py
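
download_models.py is copied and run here but is not part of this commit. A plausible minimal sketch, assuming huggingface_hub and the /app/models/* directory names that src/server/main_local.py loads from; the repo-to-directory mapping below is inferred from the model names used elsewhere in this commit:

import os
from huggingface_hub import snapshot_download

# Assumed mapping of local model directories to Hugging Face repos.
MODELS = {
    "llm_model": "google/gemma-3-4b-it",
    "tts_model": "ai4bharat/IndicF5",
    "asr_model": "ai4bharat/indic-conformer-600m-multilingual",
    "trans_en_indic": "ai4bharat/indictrans2-en-indic-dist-200M",
    "trans_indic_en": "ai4bharat/indictrans2-indic-en-dist-200M",
    "trans_indic_indic": "ai4bharat/indictrans2-indic-indic-dist-320M",
}

token = os.environ.get("HF_TOKEN")  # set from the HF_TOKEN_DOCKER build arg
for local_dir, repo_id in MODELS.items():
    snapshot_download(repo_id=repo_id, local_dir=f"/app/models/{local_dir}", token=token)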
src/server/main.py CHANGED
@@ -14,7 +14,7 @@ from pydantic_settings import BaseSettings
  from slowapi import Limiter
  from slowapi.util import get_remote_address
  import torch
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, AutoModel, BitsAndBytesConfig, Gemma3ForConditionalGeneration
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, AutoModel, Gemma3ForConditionalGeneration
  from IndicTransToolkit import IndicProcessor
  import json
  import asyncio
@@ -91,10 +91,9 @@ class LLMManager:
      async def load(self):
          if not self.is_loaded:
              try:
-                 local_path = "/app/models/llm_model"
                  self.model = await asyncio.to_thread(
                      Gemma3ForConditionalGeneration.from_pretrained,
-                     local_path,
+                     self.model_name,
                      device_map="auto",
                      quantization_config=quantization_config,
                      torch_dtype=self.torch_dtype
@@ -102,10 +101,10 @@
                  self.model.eval()
                  self.processor = await asyncio.to_thread(
                      AutoProcessor.from_pretrained,
-                     local_path
+                     self.model_name
                  )
                  self.is_loaded = True
-                 logger.info(f"LLM loaded from {local_path} on {self.device}")
+                 logger.info(f"LLM {self.model_name} loaded asynchronously on {self.device}")
              except Exception as e:
                  logger.error(f"Failed to load LLM: {str(e)}")
                  raise
@@ -269,15 +268,14 @@

      async def load(self):
          if not self.model:
-             logger.info("Loading TTS model from local path asynchronously...")
-             local_path = "/app/models/tts_model"
+             logger.info("Loading TTS model IndicF5 asynchronously...")
              self.model = await asyncio.to_thread(
                  AutoModel.from_pretrained,
-                 local_path,
+                 self.repo_id,
                  trust_remote_code=True
              )
              self.model = self.model.to(self.device_type)
-             logger.info("TTS model loaded from local path asynchronously")
+             logger.info("TTS model IndicF5 loaded asynchronously")

      def synthesize(self, text, ref_audio_path, ref_text):
          if not self.model:
@@ -364,29 +362,29 @@ class TranslateManager:
      async def load(self):
          if not self.tokenizer or not self.model:
              if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
-                 local_path = "/app/models/trans_en_indic"
+                 model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
              elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
-                 local_path = "/app/models/trans_indic_en"
+                 model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
              elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
-                 local_path = "/app/models/trans_indic_indic"
+                 model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
              else:
                  raise ValueError("Invalid language combination")

              self.tokenizer = await asyncio.to_thread(
                  AutoTokenizer.from_pretrained,
-                 local_path,
+                 model_name,
                  trust_remote_code=True
              )
              self.model = await asyncio.to_thread(
                  AutoModelForSeq2SeqLM.from_pretrained,
-                 local_path,
+                 model_name,
                  trust_remote_code=True,
                  torch_dtype=torch.float16,
                  attn_implementation="flash_attention_2"
              )
              self.model = self.model.to(self.device_type)
              self.model = torch.compile(self.model, mode="reduce-overhead")
-             logger.info(f"Translation model loaded from {local_path} asynchronously")
+             logger.info(f"Translation model {model_name} loaded asynchronously")

  class ModelManager:
      def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
@@ -396,11 +394,11 @@ class ModelManager:
          self.is_lazy_loading = is_lazy_loading

      async def load_model(self, src_lang, tgt_lang, key):
-         logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} from local path")
+         logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} asynchronously")
          translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
          await translate_manager.load()
          self.models[key] = translate_manager
-         logger.info(f"Loaded translation model for {key} from local path")
+         logger.info(f"Loaded translation model for {key} asynchronously")

      def get_model(self, src_lang, tgt_lang):
          key = self._get_model_key(src_lang, tgt_lang)
@@ -429,15 +427,14 @@ class ASRModelManager:

      async def load(self):
          if not self.model:
-             logger.info("Loading ASR model from local path asynchronously...")
-             local_path = "/app/models/asr_model"
+             logger.info("Loading ASR model asynchronously...")
              self.model = await asyncio.to_thread(
                  AutoModel.from_pretrained,
-                 local_path,
+                 "ai4bharat/indic-conformer-600m-multilingual",
                  trust_remote_code=True
              )
              self.model = self.model.to(self.device_type)
-             logger.info("ASR model loaded from local path asynchronously")
+             logger.info("ASR model loaded asynchronously")

  # Global Managers
  llm_manager = LLMManager(settings.llm_model_name)
@@ -508,12 +505,12 @@ async def lifespan(app: FastAPI):
                  translation_tasks.append(model_manager.load_model(src_lang, tgt_lang, key))

              await asyncio.gather(*tasks, *translation_tasks)
-             logger.info("All models loaded successfully from local paths")
+             logger.info("All models loaded successfully asynchronously")
          except Exception as e:
              logger.error(f"Error loading models: {str(e)}")
              raise

-     logger.info("Starting asynchronous model loading from local paths...")
+     logger.info("Starting asynchronous model loading...")
      await load_all_models()
      yield
      llm_manager.unload()
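
The net effect of this diff: main.py now loads every model by Hugging Face repo id (or the configured model_name) instead of the hard-coded /app/models/* paths, which survive only in main_local.py below. A minimal smoke-test client for a running container, assuming the host/port from Dockerfile.app's CMD (endpoint paths as declared in the server code):

import requests

BASE = "http://localhost:7860"

# Health check, then a chat round-trip through translation + LLM.
print(requests.get(f"{BASE}/v1/health", timeout=10).json())

resp = requests.post(
    f"{BASE}/v1/chat",
    json={"prompt": "What is the capital of Karnataka?",
          "src_lang": "eng_Latn", "tgt_lang": "kan_Knda"},
    timeout=300,  # generation plus translation can be slow on first call
)
resp.raise_for_status()
print(resp.json()["response"])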
src/server/main_local.py ADDED
@@ -0,0 +1,913 @@
+ import argparse
+ import io
+ import os
+ from time import time
+ from typing import List
+ import tempfile
+ import uvicorn
+ from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
+ from PIL import Image
+ from pydantic import BaseModel, field_validator
+ from pydantic_settings import BaseSettings
+ from slowapi import Limiter
+ from slowapi.util import get_remote_address
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, AutoModel, BitsAndBytesConfig, Gemma3ForConditionalGeneration
+ from IndicTransToolkit import IndicProcessor
+ import json
+ import asyncio
+ from contextlib import asynccontextmanager
+ import soundfile as sf
+ import numpy as np
+ import requests
+ from starlette.responses import StreamingResponse
+ from logging_config import logger
+ from tts_config import SPEED, ResponseFormat, config as tts_config
+ import torchaudio
+
+ # Device setup
+ if torch.cuda.is_available():
+     device = "cuda:0"
+     logger.info("GPU will be used for inference")
+ else:
+     device = "cpu"
+     logger.info("CPU will be used for inference")
+ torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
+
+ # Check CUDA availability and version
+ cuda_available = torch.cuda.is_available()
+ cuda_version = torch.version.cuda if cuda_available else None
+
+ if torch.cuda.is_available():
+     device_idx = torch.cuda.current_device()
+     capability = torch.cuda.get_device_capability(device_idx)
+     compute_capability_float = float(f"{capability[0]}.{capability[1]}")
+     print(f"CUDA version: {cuda_version}")
+     print(f"CUDA Compute Capability: {compute_capability_float}")
+ else:
+     print("CUDA is not available on this system.")
+
+ # Settings
+ class Settings(BaseSettings):
+     llm_model_name: str = "google/gemma-3-4b-it"
+     max_tokens: int = 512
+     host: str = "0.0.0.0"
+     port: int = 7860
+     chat_rate_limit: str = "100/minute"
+     speech_rate_limit: str = "5/minute"
+
+     @field_validator("chat_rate_limit", "speech_rate_limit")
+     def validate_rate_limit(cls, v):
+         if not v.count("/") == 1 or not v.split("/")[0].isdigit():
+             raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
+         return v
+
+     class Config:
+         env_file = ".env"
+
+ settings = Settings()
+
+ # Quantization config for LLM
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
+ # LLM Manager
+ class LLMManager:
+     def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
+         self.model_name = model_name
+         self.device = torch.device(device)
+         self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
+         self.model = None
+         self.processor = None
+         self.is_loaded = False
+         logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
+
+     async def load(self):
+         if not self.is_loaded:
+             try:
+                 local_path = "/app/models/llm_model"
+                 self.model = await asyncio.to_thread(
+                     Gemma3ForConditionalGeneration.from_pretrained,
+                     local_path,
+                     device_map="auto",
+                     quantization_config=quantization_config,
+                     torch_dtype=self.torch_dtype
+                 )
+                 self.model.eval()
+                 self.processor = await asyncio.to_thread(
+                     AutoProcessor.from_pretrained,
+                     local_path
+                 )
+                 self.is_loaded = True
+                 logger.info(f"LLM loaded from {local_path} on {self.device}")
+             except Exception as e:
+                 logger.error(f"Failed to load LLM: {str(e)}")
+                 raise
+
+     def unload(self):
+         if self.is_loaded:
+             del self.model
+             del self.processor
+             if self.device.type == "cuda":
+                 torch.cuda.empty_cache()
+                 logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
+             self.is_loaded = False
+             logger.info(f"LLM {self.model_name} unloaded from {self.device}")
+
+     async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
+         if not self.is_loaded:
+             await self.load()
+
+         messages_vlm = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
+             },
+             {
+                 "role": "user",
+                 "content": [{"type": "text", "text": prompt}]
+             }
+         ]
+
+         try:
+             inputs_vlm = self.processor.apply_chat_template(
+                 messages_vlm,
+                 add_generation_prompt=True,
+                 tokenize=True,
+                 return_dict=True,
+                 return_tensors="pt"
+             ).to(self.device, dtype=torch.bfloat16)
+         except Exception as e:
+             logger.error(f"Error in tokenization: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
+
+         input_len = inputs_vlm["input_ids"].shape[-1]
+
+         with torch.inference_mode():
+             generation = self.model.generate(
+                 **inputs_vlm,
+                 max_new_tokens=max_tokens,
+                 do_sample=True,
+                 temperature=temperature
+             )
+             generation = generation[0][input_len:]
+
+         response = self.processor.decode(generation, skip_special_tokens=True)
+         logger.info(f"Generated response: {response}")
+         return response
+
+     async def vision_query(self, image: Image.Image, query: str) -> str:
+         if not self.is_loaded:
+             await self.load()
+
+         messages_vlm = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
+             },
+             {
+                 "role": "user",
+                 "content": []
+             }
+         ]
+
+         messages_vlm[1]["content"].append({"type": "text", "text": query})
+         if image and image.size[0] > 0 and image.size[1] > 0:
+             messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
+             logger.info(f"Received valid image for processing")
+         else:
+             logger.info("No valid image provided, processing text only")
+
+         try:
+             inputs_vlm = self.processor.apply_chat_template(
+                 messages_vlm,
+                 add_generation_prompt=True,
+                 tokenize=True,
+                 return_dict=True,
+                 return_tensors="pt"
+             ).to(self.device, dtype=torch.bfloat16)
+         except Exception as e:
+             logger.error(f"Error in apply_chat_template: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
+
+         input_len = inputs_vlm["input_ids"].shape[-1]
+
+         with torch.inference_mode():
+             generation = self.model.generate(
+                 **inputs_vlm,
+                 max_new_tokens=512,
+                 do_sample=True,
+                 temperature=0.7
+             )
+             generation = generation[0][input_len:]
+
+         decoded = self.processor.decode(generation, skip_special_tokens=True)
+         logger.info(f"Vision query response: {decoded}")
+         return decoded
+
+     async def chat_v2(self, image: Image.Image, query: str) -> str:
+         if not self.is_loaded:
+             await self.load()
+
+         messages_vlm = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
+             },
+             {
+                 "role": "user",
+                 "content": []
+             }
+         ]
+
+         messages_vlm[1]["content"].append({"type": "text", "text": query})
+         if image and image.size[0] > 0 and image.size[1] > 0:
+             messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
+             logger.info(f"Received valid image for processing")
+         else:
+             logger.info("No valid image provided, processing text only")
+
+         try:
+             inputs_vlm = self.processor.apply_chat_template(
+                 messages_vlm,
+                 add_generation_prompt=True,
+                 tokenize=True,
+                 return_dict=True,
+                 return_tensors="pt"
+             ).to(self.device, dtype=torch.bfloat16)
+         except Exception as e:
+             logger.error(f"Error in apply_chat_template: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
+
+         input_len = inputs_vlm["input_ids"].shape[-1]
+
+         with torch.inference_mode():
+             generation = self.model.generate(
+                 **inputs_vlm,
+                 max_new_tokens=512,
+                 do_sample=True,
+                 temperature=0.7
+             )
+             generation = generation[0][input_len:]
+
+         decoded = self.processor.decode(generation, skip_special_tokens=True)
+         logger.info(f"Chat_v2 response: {decoded}")
+         return decoded
+
+ # TTS Manager
+ class TTSManager:
+     def __init__(self, device_type=device):
+         self.device_type = device_type
+         self.model = None
+         self.repo_id = "ai4bharat/IndicF5"
+
+     async def load(self):
+         if not self.model:
+             logger.info("Loading TTS model from local path asynchronously...")
+             local_path = "/app/models/tts_model"
+             self.model = await asyncio.to_thread(
+                 AutoModel.from_pretrained,
+                 local_path,
+                 trust_remote_code=True
+             )
+             self.model = self.model.to(self.device_type)
+             logger.info("TTS model loaded from local path asynchronously")
+
+     def synthesize(self, text, ref_audio_path, ref_text):
+         if not self.model:
+             raise ValueError("TTS model not loaded")
+         return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
+
+ # TTS Constants
+ EXAMPLES = [
+     {
+         "audio_name": "KAN_F (Happy)",
+         "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
+         "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
+         "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
+     },
+ ]
+
+ # Pydantic models for TTS
+ class SynthesizeRequest(BaseModel):
+     text: str
+     ref_audio_name: str
+     ref_text: str = None
+
+ class KannadaSynthesizeRequest(BaseModel):
+     text: str
+
+ # TTS Functions
+ def load_audio_from_url(url: str):
+     response = requests.get(url)
+     if response.status_code == 200:
+         audio_data, sample_rate = sf.read(io.BytesIO(response.content))
+         return sample_rate, audio_data
+     raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
+
+ def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
+     ref_audio_url = None
+     for example in EXAMPLES:
+         if example["audio_name"] == ref_audio_name:
+             ref_audio_url = example["audio_url"]
+             if not ref_text:
+                 ref_text = example["ref_text"]
+             break
+
+     if not ref_audio_url:
+         raise HTTPException(status_code=400, detail="Invalid reference audio name.")
+     if not text.strip():
+         raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
+     if not ref_text or not ref_text.strip():
+         raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
+
+     sample_rate, audio_data = load_audio_from_url(ref_audio_url)
+     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
+         sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
+         temp_audio.flush()
+         audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
+
+     if audio.dtype == np.int16:
+         audio = audio.astype(np.float32) / 32768.0
+     buffer = io.BytesIO()
+     sf.write(buffer, audio, 24000, format='WAV')
+     buffer.seek(0)
+     return buffer
+
+ # Supported languages
+ SUPPORTED_LANGUAGES = {
+     "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
+     "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
+     "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
+     "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
+     "kan_Knda", "ory_Orya",
+     "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
+     "por_Latn", "rus_Cyrl", "pol_Latn"
+ }
+
+ # Translation Manager
+ class TranslateManager:
+     def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
+         self.device_type = device_type
+         self.tokenizer = None
+         self.model = None
+         self.src_lang = src_lang
+         self.tgt_lang = tgt_lang
+         self.use_distilled = use_distilled
+
+     async def load(self):
+         if not self.tokenizer or not self.model:
+             if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
+                 local_path = "/app/models/trans_en_indic"
+             elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
+                 local_path = "/app/models/trans_indic_en"
+             elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
+                 local_path = "/app/models/trans_indic_indic"
+             else:
+                 raise ValueError("Invalid language combination")
+
+             self.tokenizer = await asyncio.to_thread(
+                 AutoTokenizer.from_pretrained,
+                 local_path,
+                 trust_remote_code=True
+             )
+             self.model = await asyncio.to_thread(
+                 AutoModelForSeq2SeqLM.from_pretrained,
+                 local_path,
+                 trust_remote_code=True,
+                 torch_dtype=torch.float16,
+                 attn_implementation="flash_attention_2"
+             )
+             self.model = self.model.to(self.device_type)
+             self.model = torch.compile(self.model, mode="reduce-overhead")
+             logger.info(f"Translation model loaded from {local_path} asynchronously")
+
+ class ModelManager:
+     def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
+         self.models = {}
+         self.device_type = device_type
+         self.use_distilled = use_distilled
+         self.is_lazy_loading = is_lazy_loading
+
+     async def load_model(self, src_lang, tgt_lang, key):
+         logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} from local path")
+         translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
+         await translate_manager.load()
+         self.models[key] = translate_manager
+         logger.info(f"Loaded translation model for {key} from local path")
+
+     def get_model(self, src_lang, tgt_lang):
+         key = self._get_model_key(src_lang, tgt_lang)
+         if key not in self.models:
+             if self.is_lazy_loading:
+                 asyncio.create_task(self.load_model(src_lang, tgt_lang, key))
+             else:
+                 raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
+         return self.models.get(key)
+
+     def _get_model_key(self, src_lang, tgt_lang):
+         if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
+             return 'eng_indic'
+         elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
+             return 'indic_eng'
+         elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
+             return 'indic_indic'
+         raise ValueError("Invalid language combination")
+
+ # ASR Manager
+ class ASRModelManager:
+     def __init__(self, device_type="cuda"):
+         self.device_type = device_type
+         self.model = None
+         self.model_language = {"kannada": "kn"}
+
+     async def load(self):
+         if not self.model:
+             logger.info("Loading ASR model from local path asynchronously...")
+             local_path = "/app/models/asr_model"
+             self.model = await asyncio.to_thread(
+                 AutoModel.from_pretrained,
+                 local_path,
+                 trust_remote_code=True
+             )
+             self.model = self.model.to(self.device_type)
+             logger.info("ASR model loaded from local path asynchronously")
+
+ # Global Managers
+ llm_manager = LLMManager(settings.llm_model_name)
+ model_manager = ModelManager()
+ asr_manager = ASRModelManager()
+ tts_manager = TTSManager()
+ ip = IndicProcessor(inference=True)
+
+ # Pydantic Models
+ class ChatRequest(BaseModel):
+     prompt: str
+     src_lang: str = "kan_Knda"
+     tgt_lang: str = "kan_Knda"
+
+     @field_validator("prompt")
+     def prompt_must_be_valid(cls, v):
+         if len(v) > 1000:
+             raise ValueError("Prompt cannot exceed 1000 characters")
+         return v.strip()
+
+     @field_validator("src_lang", "tgt_lang")
+     def validate_language(cls, v):
+         if v not in SUPPORTED_LANGUAGES:
+             raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
+         return v
+
+ class ChatResponse(BaseModel):
+     response: str
+
+ class TranslationRequest(BaseModel):
+     sentences: List[str]
+     src_lang: str
+     tgt_lang: str
+
+ class TranscriptionResponse(BaseModel):
+     text: str
+
+ class TranslationResponse(BaseModel):
+     translations: List[str]
+
+ # Dependency
+ def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
+     return model_manager.get_model(src_lang, tgt_lang)
+
+ # Lifespan Event Handler
+ translation_configs = []
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     async def load_all_models():
+         try:
+             tasks = [
+                 llm_manager.load(),
+                 tts_manager.load(),
+                 asr_manager.load(),
+             ]
+
+             translation_tasks = [
+                 model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic'),
+                 model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng'),
+                 model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic'),
+             ]
+
+             for config in translation_configs:
+                 src_lang = config["src_lang"]
+                 tgt_lang = config["tgt_lang"]
+                 key = model_manager._get_model_key(src_lang, tgt_lang)
+                 translation_tasks.append(model_manager.load_model(src_lang, tgt_lang, key))
+
+             await asyncio.gather(*tasks, *translation_tasks)
+             logger.info("All models loaded successfully from local paths")
+         except Exception as e:
+             logger.error(f"Error loading models: {str(e)}")
+             raise
+
+     logger.info("Starting asynchronous model loading from local paths...")
+     await load_all_models()
+     yield
+     llm_manager.unload()
+     logger.info("Server shutdown complete")
+
+ # FastAPI App
+ app = FastAPI(
+     title="Dhwani API",
+     description="AI Chat API supporting Indian languages",
+     version="1.0.0",
+     redirect_slashes=False,
+     lifespan=lifespan
+ )
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=False,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ limiter = Limiter(key_func=get_remote_address)
+ app.state.limiter = limiter
+
+ # API Endpoints
+ @app.post("/audio/speech", response_class=StreamingResponse)
+ async def synthesize_kannada(request: KannadaSynthesizeRequest):
+     if not tts_manager.model:
+         raise HTTPException(status_code=503, detail="TTS model not loaded")
+     kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
+     if not request.text.strip():
+         raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
+
+     audio_buffer = synthesize_speech(
+         tts_manager,
+         text=request.text,
+         ref_audio_name="KAN_F (Happy)",
+         ref_text=kannada_example["ref_text"]
+     )
+
+     return StreamingResponse(
+         audio_buffer,
+         media_type="audio/wav",
+         headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
+     )
+
+ @app.post("/translate", response_model=TranslationResponse)
+ async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
+     input_sentences = request.sentences
+     src_lang = request.src_lang
+     tgt_lang = request.tgt_lang
+
+     if not input_sentences:
+         raise HTTPException(status_code=400, detail="Input sentences are required")
+
+     batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
+     inputs = translate_manager.tokenizer(
+         batch,
+         truncation=True,
+         padding="longest",
+         return_tensors="pt",
+         return_attention_mask=True,
+     ).to(translate_manager.device_type)
+
+     with torch.no_grad():
+         generated_tokens = translate_manager.model.generate(
+             **inputs,
+             use_cache=True,
+             min_length=0,
+             max_length=256,
+             num_beams=5,
+             num_return_sequences=1,
+         )
+
+     with translate_manager.tokenizer.as_target_tokenizer():
+         generated_tokens = translate_manager.tokenizer.batch_decode(
+             generated_tokens.detach().cpu().tolist(),
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=True,
+         )
+
+     translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
+     return TranslationResponse(translations=translations)
+
+ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
+     try:
+         translate_manager = model_manager.get_model(src_lang, tgt_lang)
+     except ValueError as e:
+         logger.info(f"Model not preloaded: {str(e)}, loading now...")
+         key = model_manager._get_model_key(src_lang, tgt_lang)
+         await model_manager.load_model(src_lang, tgt_lang, key)
+         translate_manager = model_manager.get_model(src_lang, tgt_lang)
+
+     if not translate_manager.model:
+         await translate_manager.load()
+
+     request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
+     response = await translate(request, translate_manager)
+     return response.translations
+
+ @app.get("/v1/health")
+ async def health_check():
+     return {"status": "healthy", "model": settings.llm_model_name}
+
+ @app.get("/")
+ async def home():
+     return RedirectResponse(url="/docs")
+
+ @app.post("/v1/unload_all_models")
+ async def unload_all_models():
+     try:
+         logger.info("Starting to unload all models...")
+         llm_manager.unload()
+         logger.info("All models unloaded successfully")
+         return {"status": "success", "message": "All models unloaded"}
+     except Exception as e:
+         logger.error(f"Error unloading models: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")
+
+ @app.post("/v1/load_all_models")
+ async def load_all_models():
+     try:
+         logger.info("Starting to load all models...")
+         await llm_manager.load()
+         logger.info("All models loaded successfully")
+         return {"status": "success", "message": "All models loaded"}
+     except Exception as e:
+         logger.error(f"Error loading models: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")
+
+ @app.post("/v1/translate", response_model=TranslationResponse)
+ async def translate_endpoint(request: TranslationRequest):
+     logger.info(f"Received translation request: {request.dict()}")
+     try:
+         translations = await perform_internal_translation(
+             sentences=request.sentences,
+             src_lang=request.src_lang,
+             tgt_lang=request.tgt_lang
+         )
+         logger.info(f"Translation successful: {translations}")
+         return TranslationResponse(translations=translations)
+     except Exception as e:
+         logger.error(f"Unexpected error during translation: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
+
+ @app.post("/v1/chat", response_model=ChatResponse)
+ @limiter.limit(settings.chat_rate_limit)
+ async def chat(request: Request, chat_request: ChatRequest):
+     if not chat_request.prompt:
+         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+     logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
+
+     EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
+
+     try:
+         if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
+             translated_prompt = await perform_internal_translation(
+                 sentences=[chat_request.prompt],
+                 src_lang=chat_request.src_lang,
+                 tgt_lang="eng_Latn"
+             )
+             prompt_to_process = translated_prompt[0]
+             logger.info(f"Translated prompt to English: {prompt_to_process}")
+         else:
+             prompt_to_process = chat_request.prompt
+             logger.info("Prompt in English or European language, no translation needed")
+
+         response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+         logger.info(f"Generated response: {response}")
+
+         if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
+             translated_response = await perform_internal_translation(
+                 sentences=[response],
+                 src_lang="eng_Latn",
+                 tgt_lang=chat_request.tgt_lang
+             )
+             final_response = translated_response[0]
+             logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
+         else:
+             final_response = response
+             logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
+
+         return ChatResponse(response=final_response)
+     except Exception as e:
+         logger.error(f"Error processing request: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+ @app.post("/v1/visual_query/")
+ async def visual_query(
+     file: UploadFile = File(...),
+     query: str = Body(...),
+     src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
+     tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
+ ):
+     try:
+         image = Image.open(file.file)
+         if image.size == (0, 0):
+             raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
+
+         if src_lang != "eng_Latn":
+             translated_query = await perform_internal_translation(
+                 sentences=[query],
+                 src_lang=src_lang,
+                 tgt_lang="eng_Latn"
+             )
+             query_to_process = translated_query[0]
+             logger.info(f"Translated query to English: {query_to_process}")
+         else:
+             query_to_process = query
+             logger.info("Query already in English, no translation needed")
+
+         answer = await llm_manager.vision_query(image, query_to_process)
+         logger.info(f"Generated English answer: {answer}")
+
+         if tgt_lang != "eng_Latn":
+             translated_answer = await perform_internal_translation(
+                 sentences=[answer],
+                 src_lang="eng_Latn",
+                 tgt_lang=tgt_lang
+             )
+             final_answer = translated_answer[0]
+             logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
+         else:
+             final_answer = answer
+             logger.info("Answer kept in English, no translation needed")
+
+         return {"answer": final_answer}
+     except Exception as e:
+         logger.error(f"Error processing request: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+ @app.post("/v1/chat_v2", response_model=ChatResponse)
+ @limiter.limit(settings.chat_rate_limit)
+ async def chat_v2(
+     request: Request,
+     prompt: str = Form(...),
+     image: UploadFile = File(default=None),
+     src_lang: str = Form("kan_Knda"),
+     tgt_lang: str = Form("kan_Knda"),
+ ):
+     if not prompt:
+         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+     if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
+         raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
+
+     logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
+
+     try:
+         if image:
+             image_data = await image.read()
+             if not image_data:
+                 raise HTTPException(status_code=400, detail="Uploaded image is empty")
+             img = Image.open(io.BytesIO(image_data))
+
+             if src_lang != "eng_Latn":
+                 translated_prompt = await perform_internal_translation(
+                     sentences=[prompt],
+                     src_lang=src_lang,
+                     tgt_lang="eng_Latn"
+                 )
+                 prompt_to_process = translated_prompt[0]
+                 logger.info(f"Translated prompt to English: {prompt_to_process}")
+             else:
+                 prompt_to_process = prompt
+                 logger.info("Prompt already in English, no translation needed")
+
+             decoded = await llm_manager.chat_v2(img, prompt_to_process)
+             logger.info(f"Generated English response: {decoded}")
+
+             if tgt_lang != "eng_Latn":
+                 translated_response = await perform_internal_translation(
+                     sentences=[decoded],
+                     src_lang="eng_Latn",
+                     tgt_lang=tgt_lang
+                 )
+                 final_response = translated_response[0]
+                 logger.info(f"Translated response to {tgt_lang}: {final_response}")
+             else:
+                 final_response = decoded
+                 logger.info("Response kept in English, no translation needed")
+         else:
+             if src_lang != "eng_Latn":
+                 translated_prompt = await perform_internal_translation(
+                     sentences=[prompt],
+                     src_lang=src_lang,
+                     tgt_lang="eng_Latn"
+                 )
+                 prompt_to_process = translated_prompt[0]
+                 logger.info(f"Translated prompt to English: {prompt_to_process}")
+             else:
+                 prompt_to_process = prompt
+                 logger.info("Prompt already in English, no translation needed")
+
+             decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+             logger.info(f"Generated English response: {decoded}")
+
+             if tgt_lang != "eng_Latn":
+                 translated_response = await perform_internal_translation(
+                     sentences=[decoded],
+                     src_lang="eng_Latn",
+                     tgt_lang=tgt_lang
+                 )
+                 final_response = translated_response[0]
+                 logger.info(f"Translated response to {tgt_lang}: {final_response}")
+             else:
+                 final_response = decoded
+                 logger.info("Response kept in English, no translation needed")
+
+         return ChatResponse(response=final_response)
+     except Exception as e:
+         logger.error(f"Error processing request: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+ @app.post("/transcribe/", response_model=TranscriptionResponse)
+ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
+     if not asr_manager.model:
+         raise HTTPException(status_code=503, detail="ASR model not loaded")
+     try:
+         wav, sr = torchaudio.load(file.file)
+         wav = torch.mean(wav, dim=0, keepdim=True)
+         target_sample_rate = 16000
+         if sr != target_sample_rate:
+             resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
+             wav = resampler(wav)
+         transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+         return TranscriptionResponse(text=transcription_rnnt)
+     except Exception as e:
+         logger.error(f"Error in transcription: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
+
+ @app.post("/v1/speech_to_speech")
+ async def speech_to_speech(
+     request: Request,
+     file: UploadFile = File(...),
+     language: str = Query(..., enum=list(asr_manager.model_language.keys())),
+ ) -> StreamingResponse:
+     if not tts_manager.model:
+         raise HTTPException(status_code=503, detail="TTS model not loaded")
+     transcription = await transcribe_audio(file, language)
+     logger.info(f"Transcribed text: {transcription.text}")
+
+     chat_request = ChatRequest(
+         prompt=transcription.text,
+         src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
+         tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
+     )
+     processed_text = await chat(request, chat_request)
+     logger.info(f"Processed text: {processed_text.response}")
+
+     voice_request = KannadaSynthesizeRequest(text=processed_text.response)
+     audio_response = await synthesize_kannada(voice_request)
+     return audio_response
+
+ LANGUAGE_TO_SCRIPT = {
+     "kannada": "kan_Knda"
+ }
+
+ # Main Execution
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Run the FastAPI server.")
+     parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
+     parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
+     parser.add_argument("--config", type=str, default="config_one", help="Configuration to use")
+     args = parser.parse_args()
+
+     def load_config(config_path="dhwani_config.json"):
+         with open(config_path, "r") as f:
+             return json.load(f)
+
+     config_data = load_config()
+     if args.config not in config_data["configs"]:
+         raise ValueError(f"Invalid config: {args.config}. Available: {list(config_data['configs'].keys())}")
+
+     selected_config = config_data["configs"][args.config]
+     global_settings = config_data["global_settings"]
+
+     settings.llm_model_name = selected_config["components"]["LLM"]["model"]
+     settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
+     settings.host = global_settings["host"]
+     settings.port = global_settings["port"]
+     settings.chat_rate_limit = global_settings["chat_rate_limit"]
+     settings.speech_rate_limit = global_settings["speech_rate_limit"]
+
+     llm_manager = LLMManager(settings.llm_model_name)
+
+     if selected_config["components"]["ASR"]:
+         asr_model_name = selected_config["components"]["ASR"]["model"]
+         asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
+
+     if selected_config["components"]["Translation"]:
+         translation_configs.extend(selected_config["components"]["Translation"])
+
+     host = args.host if args.host != settings.host else settings.host
+     port = args.port if args.port != settings.port else settings.port
+
+     uvicorn.run(app, host=host, port=port)
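
For completeness, a client sketch for the speech endpoints defined above (paths and parameter names as declared in the route decorators; file names are placeholders):

import requests

BASE = "http://localhost:7860"

# Kannada TTS: JSON body in, WAV attachment out.
tts = requests.post(f"{BASE}/audio/speech", json={"text": "ನಮಸ್ಕಾರ"}, timeout=300)
tts.raise_for_status()
with open("synthesized_kannada_speech.wav", "wb") as f:
    f.write(tts.content)

# Kannada ASR: multipart file upload; `language` is a query parameter.
with open("sample.wav", "rb") as f:
    asr = requests.post(f"{BASE}/transcribe/", params={"language": "kannada"},
                        files={"file": f}, timeout=300)
asr.raise_for_status()
print(asr.json()["text"])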