sachin committed on
Commit
564e070
·
1 Parent(s): 230a925

add llm optimisation

Files changed (2)
  1. src/server/main.py +93 -55
  2. src/server/main_v0.py +929 -0
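
The main optimisation in this commit is moving quantization setup into LLMManager.load(): TF32 matmuls are enabled on CUDA, the model is loaded with a 4-bit NF4 BitsAndBytesConfig and a 10 GiB GPU memory cap, and the processor is created with use_fast=True. A minimal standalone sketch of that loading path (model name and memory cap are taken from the diff; this is illustrative, not the exact file contents):

# Sketch of the optimised load path introduced by this commit (illustrative).
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Gemma3ForConditionalGeneration

MODEL_NAME = "google/gemma-3-4b-it"  # Settings.llm_model_name default

if torch.cuda.is_available():
    # Allow TF32 matrix multiplications on Ampere+ GPUs for faster compute.
    torch.set_float32_matmul_precision("high")

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = Gemma3ForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    max_memory={0: "10GiB"},  # cap usage on GPU 0, as in the diff
).eval()

processor = AutoProcessor.from_pretrained(MODEL_NAME, use_fast=True)
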
src/server/main.py CHANGED
@@ -79,31 +79,44 @@ quantization_config = BitsAndBytesConfig(
79
 
80
  # LLM Manager
81
  class LLMManager:
82
- def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
83
  self.model_name = model_name
84
  self.device = torch.device(device)
85
- self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
86
  self.model = None
87
  self.processor = None
88
  self.is_loaded = False
 
89
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
90
 
91
  def load(self):
92
  if not self.is_loaded:
93
  try:
94
  self.model = Gemma3ForConditionalGeneration.from_pretrained(
95
  self.model_name,
96
  device_map="auto",
97
  quantization_config=quantization_config,
98
- torch_dtype=self.torch_dtype
99
- )
100
- self.model.eval()
101
- self.processor = AutoProcessor.from_pretrained(self.model_name)
 
102
  self.is_loaded = True
103
- logger.info(f"LLM {self.model_name} loaded on {self.device}")
104
  except Exception as e:
105
  logger.error(f"Failed to load LLM: {str(e)}")
106
- raise
107
 
108
  def unload(self):
109
  if self.is_loaded:
@@ -113,12 +126,18 @@ class LLMManager:
113
  torch.cuda.empty_cache()
114
  logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
115
  self.is_loaded = False
 
116
  logger.info(f"LLM {self.model_name} unloaded from {self.device}")
117
 
118
- async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
119
  if not self.is_loaded:
120
  self.load()
121
 
122
  messages_vlm = [
123
  {
124
  "role": "system",
@@ -131,29 +150,37 @@ class LLMManager:
131
  ]
132
 
133
  try:
134
- inputs_vlm = self.processor.apply_chat_template(
135
- messages_vlm,
136
- add_generation_prompt=True,
137
- tokenize=True,
138
- return_dict=True,
139
- return_tensors="pt"
140
- ).to(self.device, dtype=torch.bfloat16)
141
  except Exception as e:
142
  logger.error(f"Error in tokenization: {str(e)}")
143
  raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
144
 
145
  input_len = inputs_vlm["input_ids"].shape[-1]
 
146
 
147
  with torch.inference_mode():
148
  generation = self.model.generate(
149
  **inputs_vlm,
150
- max_new_tokens=max_tokens,
151
  do_sample=True,
 
152
  temperature=temperature
153
  )
154
  generation = generation[0][input_len:]
155
 
156
  response = self.processor.decode(generation, skip_special_tokens=True)
 
157
  logger.info(f"Generated response: {response}")
158
  return response
159
 
@@ -164,47 +191,53 @@ class LLMManager:
164
  messages_vlm = [
165
  {
166
  "role": "system",
167
- "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
168
  },
169
  {
170
  "role": "user",
171
- "content": []
172
  }
173
  ]
174
 
175
- messages_vlm[1]["content"].append({"type": "text", "text": query})
176
- if image and image.size[0] > 0 and image.size[1] > 0:
177
- messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
178
- logger.info(f"Received valid image for processing")
179
- else:
180
- logger.info("No valid image provided, processing text only")
181
 
182
  try:
183
- inputs_vlm = self.processor.apply_chat_template(
184
- messages_vlm,
185
- add_generation_prompt=True,
186
- tokenize=True,
187
- return_dict=True,
188
- return_tensors="pt"
189
- ).to(self.device, dtype=torch.bfloat16)
190
  except Exception as e:
191
  logger.error(f"Error in apply_chat_template: {str(e)}")
192
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
193
 
194
  input_len = inputs_vlm["input_ids"].shape[-1]
 
195
 
196
  with torch.inference_mode():
197
  generation = self.model.generate(
198
  **inputs_vlm,
199
- max_new_tokens=512,
200
  do_sample=True,
 
201
  temperature=0.7
202
  )
203
  generation = generation[0][input_len:]
204
 
205
- decoded = self.processor.decode(generation, skip_special_tokens=True)
206
- logger.info(f"Vision query response: {decoded}")
207
- return decoded
 
208
 
209
  async def chat_v2(self, image: Image.Image, query: str) -> str:
210
  if not self.is_loaded:
@@ -217,43 +250,49 @@ class LLMManager:
217
  },
218
  {
219
  "role": "user",
220
- "content": []
221
  }
222
  ]
223
 
224
- messages_vlm[1]["content"].append({"type": "text", "text": query})
225
- if image and image.size[0] > 0 and image.size[1] > 0:
226
- messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
227
- logger.info(f"Received valid image for processing")
228
- else:
229
- logger.info("No valid image provided, processing text only")
230
 
231
  try:
232
- inputs_vlm = self.processor.apply_chat_template(
233
- messages_vlm,
234
- add_generation_prompt=True,
235
- tokenize=True,
236
- return_dict=True,
237
- return_tensors="pt"
238
- ).to(self.device, dtype=torch.bfloat16)
239
  except Exception as e:
240
  logger.error(f"Error in apply_chat_template: {str(e)}")
241
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
242
 
243
  input_len = inputs_vlm["input_ids"].shape[-1]
 
244
 
245
  with torch.inference_mode():
246
  generation = self.model.generate(
247
  **inputs_vlm,
248
- max_new_tokens=512,
249
  do_sample=True,
 
250
  temperature=0.7
251
  )
252
  generation = generation[0][input_len:]
253
 
254
- decoded = self.processor.decode(generation, skip_special_tokens=True)
255
- logger.info(f"Chat_v2 response: {decoded}")
256
- return decoded
 
257
 
258
  # TTS Manager
259
  class TTSManager:
@@ -453,7 +492,6 @@ class ChatRequest(BaseModel):
453
  raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
454
  return v
455
 
456
-
457
  class ChatResponse(BaseModel):
458
  response: str
459
 
 
79
 
80
  # LLM Manager
81
  class LLMManager:
82
+ def __init__(self, model_name: str, device: str = device):
83
  self.model_name = model_name
84
  self.device = torch.device(device)
85
+ self.torch_dtype = torch_dtype
86
  self.model = None
87
  self.processor = None
88
  self.is_loaded = False
89
+ self.token_cache = {}
90
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
91
 
92
  def load(self):
93
  if not self.is_loaded:
94
  try:
95
+ if self.device.type == "cuda":
96
+ torch.set_float32_matmul_precision('high')
97
+ logger.info("Enabled TF32 matrix multiplication for improved performance")
98
+
99
+ quantization_config = BitsAndBytesConfig(
100
+ load_in_4bit=True,
101
+ bnb_4bit_quant_type="nf4",
102
+ bnb_4bit_compute_dtype=self.torch_dtype,
103
+ bnb_4bit_use_double_quant=True
104
+ )
105
+
106
  self.model = Gemma3ForConditionalGeneration.from_pretrained(
107
  self.model_name,
108
  device_map="auto",
109
  quantization_config=quantization_config,
110
+ torch_dtype=self.torch_dtype,
111
+ max_memory={0: "10GiB"}
112
+ ).eval()
113
+
114
+ self.processor = AutoProcessor.from_pretrained(self.model_name, use_fast=True)
115
  self.is_loaded = True
116
+ logger.info(f"LLM {self.model_name} loaded on {self.device} with 4-bit quantization and fast processor")
117
  except Exception as e:
118
  logger.error(f"Failed to load LLM: {str(e)}")
119
+ raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
120
 
121
  def unload(self):
122
  if self.is_loaded:
 
126
  torch.cuda.empty_cache()
127
  logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
128
  self.is_loaded = False
129
+ self.token_cache.clear()
130
  logger.info(f"LLM {self.model_name} unloaded from {self.device}")
131
 
132
+ async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
133
  if not self.is_loaded:
134
  self.load()
135
 
136
+ cache_key = f"prompt_{prompt}"
137
+ if cache_key in self.token_cache and "response" in self.token_cache[cache_key]:
138
+ logger.info("Using cached response")
139
+ return self.token_cache[cache_key]["response"]
140
+
141
  messages_vlm = [
142
  {
143
  "role": "system",
 
150
  ]
151
 
152
  try:
153
+ if cache_key in self.token_cache and "inputs" in self.token_cache[cache_key]:
154
+ inputs_vlm = self.token_cache[cache_key]["inputs"]
155
+ logger.info("Using cached tokenized input")
156
+ else:
157
+ inputs_vlm = self.processor.apply_chat_template(
158
+ messages_vlm,
159
+ add_generation_prompt=True,
160
+ tokenize=True,
161
+ return_dict=True,
162
+ return_tensors="pt"
163
+ ).to(self.device, dtype=torch.bfloat16)
164
+ self.token_cache[cache_key] = {"inputs": inputs_vlm}
165
  except Exception as e:
166
  logger.error(f"Error in tokenization: {str(e)}")
167
  raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
168
 
169
  input_len = inputs_vlm["input_ids"].shape[-1]
170
+ adjusted_max_tokens = min(max_tokens, max(20, input_len * 2))
171
 
172
  with torch.inference_mode():
173
  generation = self.model.generate(
174
  **inputs_vlm,
175
+ max_new_tokens=adjusted_max_tokens,
176
  do_sample=True,
177
+ top_p=0.9,
178
  temperature=temperature
179
  )
180
  generation = generation[0][input_len:]
181
 
182
  response = self.processor.decode(generation, skip_special_tokens=True)
183
+ self.token_cache[cache_key]["response"] = response
184
  logger.info(f"Generated response: {response}")
185
  return response
186
 
 
191
  messages_vlm = [
192
  {
193
  "role": "system",
194
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in one sentence maximum."}]
195
  },
196
  {
197
  "role": "user",
198
+ "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image else [])
199
  }
200
  ]
201
 
202
+ cache_key = f"vision_{query}_{'image' if image else 'no_image'}"
203
+ if cache_key in self.token_cache and "response" in self.token_cache[cache_key]:
204
+ logger.info("Using cached response")
205
+ return self.token_cache[cache_key]["response"]
 
 
206
 
207
  try:
208
+ if cache_key in self.token_cache and "inputs" in self.token_cache[cache_key]:
209
+ inputs_vlm = self.token_cache[cache_key]["inputs"]
210
+ logger.info("Using cached tokenized input")
211
+ else:
212
+ inputs_vlm = self.processor.apply_chat_template(
213
+ messages_vlm,
214
+ add_generation_prompt=True,
215
+ tokenize=True,
216
+ return_dict=True,
217
+ return_tensors="pt"
218
+ ).to(self.device, dtype=torch.bfloat16)
219
+ self.token_cache[cache_key] = {"inputs": inputs_vlm}
220
  except Exception as e:
221
  logger.error(f"Error in apply_chat_template: {str(e)}")
222
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
223
 
224
  input_len = inputs_vlm["input_ids"].shape[-1]
225
+ adjusted_max_tokens = min(512, max(20, input_len * 2))
226
 
227
  with torch.inference_mode():
228
  generation = self.model.generate(
229
  **inputs_vlm,
230
+ max_new_tokens=adjusted_max_tokens,
231
  do_sample=True,
232
+ top_p=0.9,
233
  temperature=0.7
234
  )
235
  generation = generation[0][input_len:]
236
 
237
+ response = self.processor.decode(generation, skip_special_tokens=True)
238
+ self.token_cache[cache_key]["response"] = response
239
+ logger.info(f"Vision query response: {response}")
240
+ return response
241
 
242
  async def chat_v2(self, image: Image.Image, query: str) -> str:
243
  if not self.is_loaded:
 
250
  },
251
  {
252
  "role": "user",
253
+ "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image else [])
254
  }
255
  ]
256
 
257
+ cache_key = f"chat_v2_{query}_{'image' if image else 'no_image'}"
258
+ if cache_key in self.token_cache and "response" in self.token_cache[cache_key]:
259
+ logger.info("Using cached response")
260
+ return self.token_cache[cache_key]["response"]
 
 
261
 
262
  try:
263
+ if cache_key in self.token_cache and "inputs" in self.token_cache[cache_key]:
264
+ inputs_vlm = self.token_cache[cache_key]["inputs"]
265
+ logger.info("Using cached tokenized input")
266
+ else:
267
+ inputs_vlm = self.processor.apply_chat_template(
268
+ messages_vlm,
269
+ add_generation_prompt=True,
270
+ tokenize=True,
271
+ return_dict=True,
272
+ return_tensors="pt"
273
+ ).to(self.device, dtype=torch.bfloat16)
274
+ self.token_cache[cache_key] = {"inputs": inputs_vlm}
275
  except Exception as e:
276
  logger.error(f"Error in apply_chat_template: {str(e)}")
277
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
278
 
279
  input_len = inputs_vlm["input_ids"].shape[-1]
280
+ adjusted_max_tokens = min(512, max(20, input_len * 2))
281
 
282
  with torch.inference_mode():
283
  generation = self.model.generate(
284
  **inputs_vlm,
285
+ max_new_tokens=adjusted_max_tokens,
286
  do_sample=True,
287
+ top_p=0.9,
288
  temperature=0.7
289
  )
290
  generation = generation[0][input_len:]
291
 
292
+ response = self.processor.decode(generation, skip_special_tokens=True)
293
+ self.token_cache[cache_key]["response"] = response
294
+ logger.info(f"Chat_v2 response: {response}")
295
+ return response
296
 
297
  # TTS Manager
298
  class TTSManager:
 
492
  raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
493
  return v
494
 
 
495
  class ChatResponse(BaseModel):
496
  response: str
497
 
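The other change in main.py is a per-prompt cache plus a dynamic generation budget: generate(), vision_query() and chat_v2() each build a cache key from the prompt (with an image marker for vision calls), reuse the tokenised inputs and the decoded response when the same key reappears, clear the cache in unload(), and cap max_new_tokens at twice the prompt length with a floor of 20. A condensed sketch of that pattern follows (the helper class and function names here are illustrative, not the diff's own):

# Simplified sketch of the token/response cache and token-budget logic added to LLMManager.
from typing import Any, Dict, Optional

class TokenCache:
    """Maps a cache key to {"inputs": tokenised prompt, "response": decoded text}."""
    def __init__(self) -> None:
        self._store: Dict[str, Dict[str, Any]] = {}

    def response(self, key: str) -> Optional[str]:
        return self._store.get(key, {}).get("response")

    def inputs(self, key: str) -> Optional[Any]:
        return self._store.get(key, {}).get("inputs")

    def put_inputs(self, key: str, inputs: Any) -> None:
        self._store[key] = {"inputs": inputs}

    def put_response(self, key: str, response: str) -> None:
        self._store.setdefault(key, {})["response"] = response

    def clear(self) -> None:
        self._store.clear()  # mirrors token_cache.clear() in unload()

def adjusted_max_tokens(requested: int, input_len: int) -> int:
    # The diff scales the generation budget with prompt length: at most 2x the
    # input length, never below 20 tokens, never above the requested maximum.
    return min(requested, max(20, input_len * 2))

Note that the cache is keyed only on the prompt text, so with do_sample=True a repeated prompt returns the first sampled answer rather than a fresh one.
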
src/server/main_v0.py ADDED
@@ -0,0 +1,929 @@
1
+ import argparse
2
+ import io
3
+ import os
4
+ from time import time
5
+ from typing import List
6
+ import tempfile
7
+ import uvicorn
8
+ from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
+ from PIL import Image
12
+ from pydantic import BaseModel, field_validator
13
+ from pydantic_settings import BaseSettings
14
+ from slowapi import Limiter
15
+ from slowapi.util import get_remote_address
16
+ import torch
17
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, AutoModel, Gemma3ForConditionalGeneration
18
+ from IndicTransToolkit import IndicProcessor
19
+ import json
20
+ import asyncio
21
+ from contextlib import asynccontextmanager
22
+ import soundfile as sf
23
+ import numpy as np
24
+ import requests
25
+ from starlette.responses import StreamingResponse
26
+ from logging_config import logger
27
+ from tts_config import SPEED, ResponseFormat, config as tts_config
28
+ import torchaudio
29
+
30
+ # Device setup
31
+ if torch.cuda.is_available():
32
+ device = "cuda:0"
33
+ logger.info("GPU will be used for inference")
34
+ else:
35
+ device = "cpu"
36
+ logger.info("CPU will be used for inference")
37
+ torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
38
+
39
+ # Check CUDA availability and version
40
+ cuda_available = torch.cuda.is_available()
41
+ cuda_version = torch.version.cuda if cuda_available else None
42
+
43
+ if torch.cuda.is_available():
44
+ device_idx = torch.cuda.current_device()
45
+ capability = torch.cuda.get_device_capability(device_idx)
46
+ compute_capability_float = float(f"{capability[0]}.{capability[1]}")
47
+ print(f"CUDA version: {cuda_version}")
48
+ print(f"CUDA Compute Capability: {compute_capability_float}")
49
+ else:
50
+ print("CUDA is not available on this system.")
51
+
52
+ # Settings
53
+ class Settings(BaseSettings):
54
+ llm_model_name: str = "google/gemma-3-4b-it"
55
+ max_tokens: int = 512
56
+ host: str = "0.0.0.0"
57
+ port: int = 7860
58
+ chat_rate_limit: str = "100/minute"
59
+ speech_rate_limit: str = "5/minute"
60
+
61
+ @field_validator("chat_rate_limit", "speech_rate_limit")
62
+ def validate_rate_limit(cls, v):
63
+ if not v.count("/") == 1 or not v.split("/")[0].isdigit():
64
+ raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
65
+ return v
66
+
67
+ class Config:
68
+ env_file = ".env"
69
+
70
+ settings = Settings()
71
+
72
+ # Quantization config for LLM
73
+ quantization_config = BitsAndBytesConfig(
74
+ load_in_4bit=True,
75
+ bnb_4bit_quant_type="nf4",
76
+ bnb_4bit_use_double_quant=True,
77
+ bnb_4bit_compute_dtype=torch.bfloat16
78
+ )
79
+
80
+ # LLM Manager
81
+ class LLMManager:
82
+ def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
83
+ self.model_name = model_name
84
+ self.device = torch.device(device)
85
+ self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
86
+ self.model = None
87
+ self.processor = None
88
+ self.is_loaded = False
89
+ logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
90
+
91
+ def load(self):
92
+ if not self.is_loaded:
93
+ try:
94
+ self.model = Gemma3ForConditionalGeneration.from_pretrained(
95
+ self.model_name,
96
+ device_map="auto",
97
+ quantization_config=quantization_config,
98
+ torch_dtype=self.torch_dtype
99
+ )
100
+ self.model.eval()
101
+ self.processor = AutoProcessor.from_pretrained(self.model_name)
102
+ self.is_loaded = True
103
+ logger.info(f"LLM {self.model_name} loaded on {self.device}")
104
+ except Exception as e:
105
+ logger.error(f"Failed to load LLM: {str(e)}")
106
+ raise
107
+
108
+ def unload(self):
109
+ if self.is_loaded:
110
+ del self.model
111
+ del self.processor
112
+ if self.device.type == "cuda":
113
+ torch.cuda.empty_cache()
114
+ logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
115
+ self.is_loaded = False
116
+ logger.info(f"LLM {self.model_name} unloaded from {self.device}")
117
+
118
+ async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
119
+ if not self.is_loaded:
120
+ self.load()
121
+
122
+ messages_vlm = [
123
+ {
124
+ "role": "system",
125
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
126
+ },
127
+ {
128
+ "role": "user",
129
+ "content": [{"type": "text", "text": prompt}]
130
+ }
131
+ ]
132
+
133
+ try:
134
+ inputs_vlm = self.processor.apply_chat_template(
135
+ messages_vlm,
136
+ add_generation_prompt=True,
137
+ tokenize=True,
138
+ return_dict=True,
139
+ return_tensors="pt"
140
+ ).to(self.device, dtype=torch.bfloat16)
141
+ except Exception as e:
142
+ logger.error(f"Error in tokenization: {str(e)}")
143
+ raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
144
+
145
+ input_len = inputs_vlm["input_ids"].shape[-1]
146
+
147
+ with torch.inference_mode():
148
+ generation = self.model.generate(
149
+ **inputs_vlm,
150
+ max_new_tokens=max_tokens,
151
+ do_sample=True,
152
+ temperature=temperature
153
+ )
154
+ generation = generation[0][input_len:]
155
+
156
+ response = self.processor.decode(generation, skip_special_tokens=True)
157
+ logger.info(f"Generated response: {response}")
158
+ return response
159
+
160
+ async def vision_query(self, image: Image.Image, query: str) -> str:
161
+ if not self.is_loaded:
162
+ self.load()
163
+
164
+ messages_vlm = [
165
+ {
166
+ "role": "system",
167
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
168
+ },
169
+ {
170
+ "role": "user",
171
+ "content": []
172
+ }
173
+ ]
174
+
175
+ messages_vlm[1]["content"].append({"type": "text", "text": query})
176
+ if image and image.size[0] > 0 and image.size[1] > 0:
177
+ messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
178
+ logger.info(f"Received valid image for processing")
179
+ else:
180
+ logger.info("No valid image provided, processing text only")
181
+
182
+ try:
183
+ inputs_vlm = self.processor.apply_chat_template(
184
+ messages_vlm,
185
+ add_generation_prompt=True,
186
+ tokenize=True,
187
+ return_dict=True,
188
+ return_tensors="pt"
189
+ ).to(self.device, dtype=torch.bfloat16)
190
+ except Exception as e:
191
+ logger.error(f"Error in apply_chat_template: {str(e)}")
192
+ raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
193
+
194
+ input_len = inputs_vlm["input_ids"].shape[-1]
195
+
196
+ with torch.inference_mode():
197
+ generation = self.model.generate(
198
+ **inputs_vlm,
199
+ max_new_tokens=512,
200
+ do_sample=True,
201
+ temperature=0.7
202
+ )
203
+ generation = generation[0][input_len:]
204
+
205
+ decoded = self.processor.decode(generation, skip_special_tokens=True)
206
+ logger.info(f"Vision query response: {decoded}")
207
+ return decoded
208
+
209
+ async def chat_v2(self, image: Image.Image, query: str) -> str:
210
+ if not self.is_loaded:
211
+ self.load()
212
+
213
+ messages_vlm = [
214
+ {
215
+ "role": "system",
216
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
217
+ },
218
+ {
219
+ "role": "user",
220
+ "content": []
221
+ }
222
+ ]
223
+
224
+ messages_vlm[1]["content"].append({"type": "text", "text": query})
225
+ if image and image.size[0] > 0 and image.size[1] > 0:
226
+ messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
227
+ logger.info(f"Received valid image for processing")
228
+ else:
229
+ logger.info("No valid image provided, processing text only")
230
+
231
+ try:
232
+ inputs_vlm = self.processor.apply_chat_template(
233
+ messages_vlm,
234
+ add_generation_prompt=True,
235
+ tokenize=True,
236
+ return_dict=True,
237
+ return_tensors="pt"
238
+ ).to(self.device, dtype=torch.bfloat16)
239
+ except Exception as e:
240
+ logger.error(f"Error in apply_chat_template: {str(e)}")
241
+ raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
242
+
243
+ input_len = inputs_vlm["input_ids"].shape[-1]
244
+
245
+ with torch.inference_mode():
246
+ generation = self.model.generate(
247
+ **inputs_vlm,
248
+ max_new_tokens=512,
249
+ do_sample=True,
250
+ temperature=0.7
251
+ )
252
+ generation = generation[0][input_len:]
253
+
254
+ decoded = self.processor.decode(generation, skip_special_tokens=True)
255
+ logger.info(f"Chat_v2 response: {decoded}")
256
+ return decoded
257
+
258
+ # TTS Manager
259
+ class TTSManager:
260
+ def __init__(self, device_type=device):
261
+ self.device_type = device_type
262
+ self.model = None
263
+ self.repo_id = "ai4bharat/IndicF5"
264
+
265
+ def load(self):
266
+ if not self.model:
267
+ logger.info("Loading TTS model IndicF5...")
268
+ self.model = AutoModel.from_pretrained(
269
+ self.repo_id,
270
+ trust_remote_code=True
271
+ )
272
+ self.model = self.model.to(self.device_type)
273
+ logger.info("TTS model IndicF5 loaded")
274
+
275
+ def synthesize(self, text, ref_audio_path, ref_text):
276
+ if not self.model:
277
+ raise ValueError("TTS model not loaded")
278
+ return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
279
+
280
+ # TTS Constants
281
+ EXAMPLES = [
282
+ {
283
+ "audio_name": "KAN_F (Happy)",
284
+ "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
285
+ "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
286
+ "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
287
+ },
288
+ ]
289
+
290
+ # Pydantic models for TTS
291
+ class SynthesizeRequest(BaseModel):
292
+ text: str
293
+ ref_audio_name: str
294
+ ref_text: str = None
295
+
296
+ class KannadaSynthesizeRequest(BaseModel):
297
+ text: str
298
+
299
+ # TTS Functions
300
+ def load_audio_from_url(url: str):
301
+ response = requests.get(url)
302
+ if response.status_code == 200:
303
+ audio_data, sample_rate = sf.read(io.BytesIO(response.content))
304
+ return sample_rate, audio_data
305
+ raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
306
+
307
+ def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
308
+ ref_audio_url = None
309
+ for example in EXAMPLES:
310
+ if example["audio_name"] == ref_audio_name:
311
+ ref_audio_url = example["audio_url"]
312
+ if not ref_text:
313
+ ref_text = example["ref_text"]
314
+ break
315
+
316
+ if not ref_audio_url:
317
+ raise HTTPException(status_code=400, detail="Invalid reference audio name.")
318
+ if not text.strip():
319
+ raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
320
+ if not ref_text or not ref_text.strip():
321
+ raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
322
+
323
+ sample_rate, audio_data = load_audio_from_url(ref_audio_url)
324
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
325
+ sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
326
+ temp_audio.flush()
327
+ audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
328
+
329
+ if audio.dtype == np.int16:
330
+ audio = audio.astype(np.float32) / 32768.0
331
+ buffer = io.BytesIO()
332
+ sf.write(buffer, audio, 24000, format='WAV')
333
+ buffer.seek(0)
334
+ return buffer
335
+
336
+ # Supported languages
337
+ SUPPORTED_LANGUAGES = {
338
+ "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
339
+ "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
340
+ "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
341
+ "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
342
+ "kan_Knda", "ory_Orya",
343
+ "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
344
+ "por_Latn", "rus_Cyrl", "pol_Latn"
345
+ }
346
+
347
+ # Translation Manager
348
+ class TranslateManager:
349
+ def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
350
+ self.device_type = device_type
351
+ self.tokenizer = None
352
+ self.model = None
353
+ self.src_lang = src_lang
354
+ self.tgt_lang = tgt_lang
355
+ self.use_distilled = use_distilled
356
+
357
+ def load(self):
358
+ if not self.tokenizer or not self.model:
359
+ if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
360
+ model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
361
+ elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
362
+ model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
363
+ elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
364
+ model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
365
+ else:
366
+ raise ValueError("Invalid language combination")
367
+
368
+ self.tokenizer = AutoTokenizer.from_pretrained(
369
+ model_name,
370
+ trust_remote_code=True
371
+ )
372
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
373
+ model_name,
374
+ trust_remote_code=True,
375
+ torch_dtype=torch.float16,
376
+ attn_implementation="flash_attention_2"
377
+ )
378
+ self.model = self.model.to(self.device_type)
379
+ self.model = torch.compile(self.model, mode="reduce-overhead")
380
+ logger.info(f"Translation model {model_name} loaded")
381
+
382
+ class ModelManager:
383
+ def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
384
+ self.models = {}
385
+ self.device_type = device_type
386
+ self.use_distilled = use_distilled
387
+ self.is_lazy_loading = is_lazy_loading
388
+
389
+ def load_model(self, src_lang, tgt_lang, key):
390
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
391
+ translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
392
+ translate_manager.load()
393
+ self.models[key] = translate_manager
394
+ logger.info(f"Loaded translation model for {key}")
395
+
396
+ def get_model(self, src_lang, tgt_lang):
397
+ key = self._get_model_key(src_lang, tgt_lang)
398
+ if key not in self.models:
399
+ if self.is_lazy_loading:
400
+ self.load_model(src_lang, tgt_lang, key)
401
+ else:
402
+ raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
403
+ return self.models.get(key)
404
+
405
+ def _get_model_key(self, src_lang, tgt_lang):
406
+ if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
407
+ return 'eng_indic'
408
+ elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
409
+ return 'indic_eng'
410
+ elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
411
+ return 'indic_indic'
412
+ raise ValueError("Invalid language combination")
413
+
414
+ # ASR Manager
415
+ class ASRModelManager:
416
+ def __init__(self, device_type="cuda"):
417
+ self.device_type = device_type
418
+ self.model = None
419
+ self.model_language = {"kannada": "kn"}
420
+
421
+ def load(self):
422
+ if not self.model:
423
+ logger.info("Loading ASR model...")
424
+ self.model = AutoModel.from_pretrained(
425
+ "ai4bharat/indic-conformer-600m-multilingual",
426
+ trust_remote_code=True
427
+ )
428
+ self.model = self.model.to(self.device_type)
429
+ logger.info("ASR model loaded")
430
+
431
+ # Global Managers
432
+ llm_manager = LLMManager(settings.llm_model_name)
433
+ model_manager = ModelManager()
434
+ asr_manager = ASRModelManager()
435
+ tts_manager = TTSManager()
436
+ ip = IndicProcessor(inference=True)
437
+
438
+ # Pydantic Models
439
+ class ChatRequest(BaseModel):
440
+ prompt: str
441
+ src_lang: str = "kan_Knda"
442
+ tgt_lang: str = "kan_Knda"
443
+
444
+ @field_validator("prompt")
445
+ def prompt_must_be_valid(cls, v):
446
+ if len(v) > 1000:
447
+ raise ValueError("Prompt cannot exceed 1000 characters")
448
+ return v.strip()
449
+
450
+ @field_validator("src_lang", "tgt_lang")
451
+ def validate_language(cls, v):
452
+ if v not in SUPPORTED_LANGUAGES:
453
+ raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
454
+ return v
455
+
456
+
457
+ class ChatResponse(BaseModel):
458
+ response: str
459
+
460
+ class TranslationRequest(BaseModel):
461
+ sentences: List[str]
462
+ src_lang: str
463
+ tgt_lang: str
464
+
465
+ class TranscriptionResponse(BaseModel):
466
+ text: str
467
+
468
+ class TranslationResponse(BaseModel):
469
+ translations: List[str]
470
+
471
+ # Dependency
472
+ def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
473
+ return model_manager.get_model(src_lang, tgt_lang)
474
+
475
+ # Lifespan Event Handler
476
+ translation_configs = []
477
+
478
+ @asynccontextmanager
479
+ async def lifespan(app: FastAPI):
480
+ def load_all_models():
481
+ try:
482
+ # Load LLM model
483
+ logger.info("Loading LLM model...")
484
+ llm_manager.load()
485
+ logger.info("LLM model loaded successfully")
486
+
487
+ # Load TTS model
488
+ logger.info("Loading TTS model...")
489
+ tts_manager.load()
490
+ logger.info("TTS model loaded successfully")
491
+
492
+ # Load ASR model
493
+ logger.info("Loading ASR model...")
494
+ asr_manager.load()
495
+ logger.info("ASR model loaded successfully")
496
+
497
+ # Load translation models
498
+ translation_tasks = [
499
+ ('eng_Latn', 'kan_Knda', 'eng_indic'),
500
+ ('kan_Knda', 'eng_Latn', 'indic_eng'),
501
+ ('kan_Knda', 'hin_Deva', 'indic_indic'),
502
+ ]
503
+
504
+ for config in translation_configs:
505
+ src_lang = config["src_lang"]
506
+ tgt_lang = config["tgt_lang"]
507
+ key = model_manager._get_model_key(src_lang, tgt_lang)
508
+ translation_tasks.append((src_lang, tgt_lang, key))
509
+
510
+ for src_lang, tgt_lang, key in translation_tasks:
511
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
512
+ model_manager.load_model(src_lang, tgt_lang, key)
513
+ logger.info(f"Translation model for {key} loaded successfully")
514
+
515
+ logger.info("All models loaded successfully")
516
+ except Exception as e:
517
+ logger.error(f"Error loading models: {str(e)}")
518
+ raise
519
+
520
+ logger.info("Starting sequential model loading...")
521
+ load_all_models()
522
+ yield
523
+ llm_manager.unload()
524
+ logger.info("Server shutdown complete")
525
+
526
+ # FastAPI App
527
+ app = FastAPI(
528
+ title="Dhwani API",
529
+ description="AI Chat API supporting Indian languages",
530
+ version="1.0.0",
531
+ redirect_slashes=False,
532
+ lifespan=lifespan
533
+ )
534
+
535
+ # Add CORS Middleware
536
+ app.add_middleware(
537
+ CORSMiddleware,
538
+ allow_origins=["*"],
539
+ allow_credentials=False,
540
+ allow_methods=["*"],
541
+ allow_headers=["*"],
542
+ )
543
+
544
+ # Add Timing Middleware
545
+ @app.middleware("http")
546
+ async def add_request_timing(request: Request, call_next):
547
+ start_time = time()
548
+ response = await call_next(request)
549
+ end_time = time()
550
+ duration = end_time - start_time
551
+ logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
552
+ response.headers["X-Response-Time"] = f"{duration:.3f}"
553
+ return response
554
+
555
+ limiter = Limiter(key_func=get_remote_address)
556
+ app.state.limiter = limiter
557
+
558
+ # API Endpoints
559
+ @app.post("/audio/speech", response_class=StreamingResponse)
560
+ async def synthesize_kannada(request: KannadaSynthesizeRequest):
561
+ if not tts_manager.model:
562
+ raise HTTPException(status_code=503, detail="TTS model not loaded")
563
+ kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
564
+ if not request.text.strip():
565
+ raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
566
+
567
+ audio_buffer = synthesize_speech(
568
+ tts_manager,
569
+ text=request.text,
570
+ ref_audio_name="KAN_F (Happy)",
571
+ ref_text=kannada_example["ref_text"]
572
+ )
573
+
574
+ return StreamingResponse(
575
+ audio_buffer,
576
+ media_type="audio/wav",
577
+ headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
578
+ )
579
+
580
+ @app.post("/translate", response_model=TranslationResponse)
581
+ async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
582
+ input_sentences = request.sentences
583
+ src_lang = request.src_lang
584
+ tgt_lang = request.tgt_lang
585
+
586
+ if not input_sentences:
587
+ raise HTTPException(status_code=400, detail="Input sentences are required")
588
+
589
+ batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
590
+ inputs = translate_manager.tokenizer(
591
+ batch,
592
+ truncation=True,
593
+ padding="longest",
594
+ return_tensors="pt",
595
+ return_attention_mask=True,
596
+ ).to(translate_manager.device_type)
597
+
598
+ with torch.no_grad():
599
+ generated_tokens = translate_manager.model.generate(
600
+ **inputs,
601
+ use_cache=True,
602
+ min_length=0,
603
+ max_length=256,
604
+ num_beams=5,
605
+ num_return_sequences=1,
606
+ )
607
+
608
+ with translate_manager.tokenizer.as_target_tokenizer():
609
+ generated_tokens = translate_manager.tokenizer.batch_decode(
610
+ generated_tokens.detach().cpu().tolist(),
611
+ skip_special_tokens=True,
612
+ clean_up_tokenization_spaces=True,
613
+ )
614
+
615
+ translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
616
+ return TranslationResponse(translations=translations)
617
+
618
+ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
619
+ try:
620
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
621
+ except ValueError as e:
622
+ logger.info(f"Model not preloaded: {str(e)}, loading now...")
623
+ key = model_manager._get_model_key(src_lang, tgt_lang)
624
+ model_manager.load_model(src_lang, tgt_lang, key)
625
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
626
+
627
+ if not translate_manager.model:
628
+ translate_manager.load()
629
+
630
+ request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
631
+ response = await translate(request, translate_manager)
632
+ return response.translations
633
+
634
+ @app.get("/v1/health")
635
+ async def health_check():
636
+ return {"status": "healthy", "model": settings.llm_model_name}
637
+
638
+ @app.get("/")
639
+ async def home():
640
+ return RedirectResponse(url="/docs")
641
+
642
+ @app.post("/v1/unload_all_models")
643
+ async def unload_all_models():
644
+ try:
645
+ logger.info("Starting to unload all models...")
646
+ llm_manager.unload()
647
+ logger.info("All models unloaded successfully")
648
+ return {"status": "success", "message": "All models unloaded"}
649
+ except Exception as e:
650
+ logger.error(f"Error unloading models: {str(e)}")
651
+ raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")
652
+
653
+ @app.post("/v1/load_all_models")
654
+ async def load_all_models():
655
+ try:
656
+ logger.info("Starting to load all models...")
657
+ llm_manager.load()
658
+ logger.info("All models loaded successfully")
659
+ return {"status": "success", "message": "All models loaded"}
660
+ except Exception as e:
661
+ logger.error(f"Error loading models: {str(e)}")
662
+ raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")
663
+
664
+ @app.post("/v1/translate", response_model=TranslationResponse)
665
+ async def translate_endpoint(request: TranslationRequest):
666
+ logger.info(f"Received translation request: {request.dict()}")
667
+ try:
668
+ translations = await perform_internal_translation(
669
+ sentences=request.sentences,
670
+ src_lang=request.src_lang,
671
+ tgt_lang=request.tgt_lang
672
+ )
673
+ logger.info(f"Translation successful: {translations}")
674
+ return TranslationResponse(translations=translations)
675
+ except Exception as e:
676
+ logger.error(f"Unexpected error during translation: {str(e)}")
677
+ raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
678
+
679
+ @app.post("/v1/chat", response_model=ChatResponse)
680
+ @limiter.limit(settings.chat_rate_limit)
681
+ async def chat(request: Request, chat_request: ChatRequest):
682
+ if not chat_request.prompt:
683
+ raise HTTPException(status_code=400, detail="Prompt cannot be empty")
684
+ logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
685
+
686
+ EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
687
+
688
+ try:
689
+ if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
690
+ translated_prompt = await perform_internal_translation(
691
+ sentences=[chat_request.prompt],
692
+ src_lang=chat_request.src_lang,
693
+ tgt_lang="eng_Latn"
694
+ )
695
+ prompt_to_process = translated_prompt[0]
696
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
697
+ else:
698
+ prompt_to_process = chat_request.prompt
699
+ logger.info("Prompt in English or European language, no translation needed")
700
+
701
+ response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
702
+ logger.info(f"Generated response: {response}")
703
+
704
+ if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
705
+ translated_response = await perform_internal_translation(
706
+ sentences=[response],
707
+ src_lang="eng_Latn",
708
+ tgt_lang=chat_request.tgt_lang
709
+ )
710
+ final_response = translated_response[0]
711
+ logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
712
+ else:
713
+ final_response = response
714
+ logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
715
+
716
+ return ChatResponse(response=final_response)
717
+ except Exception as e:
718
+ logger.error(f"Error processing request: {str(e)}")
719
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
720
+
721
+ @app.post("/v1/visual_query/")
722
+ async def visual_query(
723
+ file: UploadFile = File(...),
724
+ query: str = Body(...),
725
+ src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
726
+ tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
727
+ ):
728
+ try:
729
+ image = Image.open(file.file)
730
+ if image.size == (0, 0):
731
+ raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
732
+
733
+ if src_lang != "eng_Latn":
734
+ translated_query = await perform_internal_translation(
735
+ sentences=[query],
736
+ src_lang=src_lang,
737
+ tgt_lang="eng_Latn"
738
+ )
739
+ query_to_process = translated_query[0]
740
+ logger.info(f"Translated query to English: {query_to_process}")
741
+ else:
742
+ query_to_process = query
743
+ logger.info("Query already in English, no translation needed")
744
+
745
+ answer = await llm_manager.vision_query(image, query_to_process)
746
+ logger.info(f"Generated English answer: {answer}")
747
+
748
+ if tgt_lang != "eng_Latn":
749
+ translated_answer = await perform_internal_translation(
750
+ sentences=[answer],
751
+ src_lang="eng_Latn",
752
+ tgt_lang=tgt_lang
753
+ )
754
+ final_answer = translated_answer[0]
755
+ logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
756
+ else:
757
+ final_answer = answer
758
+ logger.info("Answer kept in English, no translation needed")
759
+
760
+ return {"answer": final_answer}
761
+ except Exception as e:
762
+ logger.error(f"Error processing request: {str(e)}")
763
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
764
+
765
+ @app.post("/v1/chat_v2", response_model=ChatResponse)
766
+ @limiter.limit(settings.chat_rate_limit)
767
+ async def chat_v2(
768
+ request: Request,
769
+ prompt: str = Form(...),
770
+ image: UploadFile = File(default=None),
771
+ src_lang: str = Form("kan_Knda"),
772
+ tgt_lang: str = Form("kan_Knda"),
773
+ ):
774
+ if not prompt:
775
+ raise HTTPException(status_code=400, detail="Prompt cannot be empty")
776
+ if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
777
+ raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
778
+
779
+ logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
780
+
781
+ try:
782
+ if image:
783
+ image_data = await image.read()
784
+ if not image_data:
785
+ raise HTTPException(status_code=400, detail="Uploaded image is empty")
786
+ img = Image.open(io.BytesIO(image_data))
787
+
788
+ if src_lang != "eng_Latn":
789
+ translated_prompt = await perform_internal_translation(
790
+ sentences=[prompt],
791
+ src_lang=src_lang,
792
+ tgt_lang="eng_Latn"
793
+ )
794
+ prompt_to_process = translated_prompt[0]
795
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
796
+ else:
797
+ prompt_to_process = prompt
798
+ logger.info("Prompt already in English, no translation needed")
799
+
800
+ decoded = await llm_manager.chat_v2(img, prompt_to_process)
801
+ logger.info(f"Generated English response: {decoded}")
802
+
803
+ if tgt_lang != "eng_Latn":
804
+ translated_response = await perform_internal_translation(
805
+ sentences=[decoded],
806
+ src_lang="eng_Latn",
807
+ tgt_lang=tgt_lang
808
+ )
809
+ final_response = translated_response[0]
810
+ logger.info(f"Translated response to {tgt_lang}: {final_response}")
811
+ else:
812
+ final_response = decoded
813
+ logger.info("Response kept in English, no translation needed")
814
+ else:
815
+ if src_lang != "eng_Latn":
816
+ translated_prompt = await perform_internal_translation(
817
+ sentences=[prompt],
818
+ src_lang=src_lang,
819
+ tgt_lang="eng_Latn"
820
+ )
821
+ prompt_to_process = translated_prompt[0]
822
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
823
+ else:
824
+ prompt_to_process = prompt
825
+ logger.info("Prompt already in English, no translation needed")
826
+
827
+ decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
828
+ logger.info(f"Generated English response: {decoded}")
829
+
830
+ if tgt_lang != "eng_Latn":
831
+ translated_response = await perform_internal_translation(
832
+ sentences=[decoded],
833
+ src_lang="eng_Latn",
834
+ tgt_lang=tgt_lang
835
+ )
836
+ final_response = translated_response[0]
837
+ logger.info(f"Translated response to {tgt_lang}: {final_response}")
838
+ else:
839
+ final_response = decoded
840
+ logger.info("Response kept in English, no translation needed")
841
+
842
+ return ChatResponse(response=final_response)
843
+ except Exception as e:
844
+ logger.error(f"Error processing request: {str(e)}")
845
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
846
+
847
+ @app.post("/transcribe/", response_model=TranscriptionResponse)
848
+ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
849
+ if not asr_manager.model:
850
+ raise HTTPException(status_code=503, detail="ASR model not loaded")
851
+ try:
852
+ wav, sr = torchaudio.load(file.file)
853
+ wav = torch.mean(wav, dim=0, keepdim=True)
854
+ target_sample_rate = 16000
855
+ if sr != target_sample_rate:
856
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
857
+ wav = resampler(wav)
858
+ transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
859
+ return TranscriptionResponse(text=transcription_rnnt)
860
+ except Exception as e:
861
+ logger.error(f"Error in transcription: {str(e)}")
862
+ raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
863
+
864
+ @app.post("/v1/speech_to_speech")
865
+ async def speech_to_speech(
866
+ request: Request,
867
+ file: UploadFile = File(...),
868
+ language: str = Query(..., enum=list(asr_manager.model_language.keys())),
869
+ ) -> StreamingResponse:
870
+ if not tts_manager.model:
871
+ raise HTTPException(status_code=503, detail="TTS model not loaded")
872
+ transcription = await transcribe_audio(file, language)
873
+ logger.info(f"Transcribed text: {transcription.text}")
874
+
875
+ chat_request = ChatRequest(
876
+ prompt=transcription.text,
877
+ src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
878
+ tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
879
+ )
880
+ processed_text = await chat(request, chat_request)
881
+ logger.info(f"Processed text: {processed_text.response}")
882
+
883
+ voice_request = KannadaSynthesizeRequest(text=processed_text.response)
884
+ audio_response = await synthesize_kannada(voice_request)
885
+ return audio_response
886
+
887
+ LANGUAGE_TO_SCRIPT = {
888
+ "kannada": "kan_Knda"
889
+ }
890
+
891
+ # Main Execution
892
+ if __name__ == "__main__":
893
+ parser = argparse.ArgumentParser(description="Run the FastAPI server.")
894
+ parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
895
+ parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
896
+ parser.add_argument("--config", type=str, default="config_one", help="Configuration to use")
897
+ args = parser.parse_args()
898
+
899
+ def load_config(config_path="dhwani_config.json"):
900
+ with open(config_path, "r") as f:
901
+ return json.load(f)
902
+
903
+ config_data = load_config()
904
+ if args.config not in config_data["configs"]:
905
+ raise ValueError(f"Invalid config: {args.config}. Available: {list(config_data['configs'].keys())}")
906
+
907
+ selected_config = config_data["configs"][args.config]
908
+ global_settings = config_data["global_settings"]
909
+
910
+ settings.llm_model_name = selected_config["components"]["LLM"]["model"]
911
+ settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
912
+ settings.host = global_settings["host"]
913
+ settings.port = global_settings["port"]
914
+ settings.chat_rate_limit = global_settings["chat_rate_limit"]
915
+ settings.speech_rate_limit = global_settings["speech_rate_limit"]
916
+
917
+ llm_manager = LLMManager(settings.llm_model_name)
918
+
919
+ if selected_config["components"]["ASR"]:
920
+ asr_model_name = selected_config["components"]["ASR"]["model"]
921
+ asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
922
+
923
+ if selected_config["components"]["Translation"]:
924
+ translation_configs.extend(selected_config["components"]["Translation"])
925
+
926
+ host = args.host if args.host != settings.host else settings.host
927
+ port = args.port if args.port != settings.port else settings.port
928
+
929
+ uvicorn.run(app, host=host, port=port)
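
main_v0.py appears to preserve the pre-optimisation server verbatim, so both versions expose the same HTTP surface. A hypothetical client call against the /v1/chat endpoint, assuming the server is running locally on the default port 7860 from Settings:

# Hypothetical client request; host, port and prompt are examples, not part of the commit.
import requests

resp = requests.post(
    "http://localhost:7860/v1/chat",
    json={
        "prompt": "What is the capital of Karnataka?",
        "src_lang": "eng_Latn",  # must be one of SUPPORTED_LANGUAGES
        "tgt_lang": "eng_Latn",
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])  # ChatResponse.response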