sachin committed on
Commit 773ab72 · 1 Parent(s): 0a0efec
Files changed (1)
  1. src/server/main.py +215 -7
src/server/main.py CHANGED
@@ -14,7 +14,7 @@ from pydantic_settings import BaseSettings
 from slowapi import Limiter
 from slowapi.util import get_remote_address
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModel, AutoProcessor, BitsAndBytesConfig
 from IndicTransToolkit import IndicProcessor
 import json
 import asyncio
@@ -25,7 +25,6 @@ import requests
 from starlette.responses import StreamingResponse
 from logging_config import logger
 from tts_config import SPEED, ResponseFormat, config as tts_config
-from gemma_llm import LLMManager  # Assuming this is your custom LLMManager
 
 # Device setup
 if torch.cuda.is_available():
@@ -69,6 +68,209 @@ class Settings(BaseSettings):
 
 settings = Settings()
 
+# Quantization config for LLM
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_compute_dtype=torch.bfloat16
+)
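+# Note: NF4 4-bit weights with double quantization cut weight memory roughly 4x
+# versus fp16; bnb_4bit_compute_dtype keeps the dequantized matmuls in bfloat16.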
+
+# LLM Manager (adapted from gemma_llm.py)
+class LLMManager:
+    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
+        self.model_name = model_name
+        self.device = torch.device(device)
+        self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
+        self.model = None
+        self.is_loaded = False
+        self.processor = None
+        logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
+
+    async def unload(self):
+        if self.is_loaded:
+            await asyncio.to_thread(self._unload_sync)
+            self.is_loaded = False
+            logger.info(f"LLM {self.model_name} unloaded from {self.device}")
+
+    def _unload_sync(self):
+        del self.model
+        del self.processor
+        if self.device.type == "cuda":
+            torch.cuda.empty_cache()
+            logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
+
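+    # from_pretrained blocks for a long time (download plus weight init), so it
+    # runs in a worker thread via asyncio.to_thread to keep the event loop free.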
+    async def load(self):
+        if not self.is_loaded:
+            try:
+                self.model = await asyncio.to_thread(
+                    AutoModel.from_pretrained,
+                    self.model_name,
+                    device_map="auto",
+                    quantization_config=quantization_config,
+                    torch_dtype=self.torch_dtype
+                )
+                self.model.eval()
+                self.processor = await asyncio.to_thread(AutoProcessor.from_pretrained, self.model_name)
+                self.is_loaded = True
+                logger.info(f"LLM {self.model_name} loaded on {self.device} with 4-bit quantization")
+            except Exception as e:
+                logger.error(f"Failed to load model: {str(e)}")
+                raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
+
+    async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
+        if not self.is_loaded:
+            await self.load()
+
+        messages_vlm = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
+            },
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": prompt}]
+            }
+        ]
+
+        try:
+            inputs_vlm = await asyncio.to_thread(
+                self.processor.apply_chat_template,
+                messages_vlm,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt"
+            )
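+            # BatchFeature.to(device, dtype=...) casts only floating-point tensors,
+            # so the integer input_ids are moved to the device but not cast.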
+            inputs_vlm = inputs_vlm.to(self.device, dtype=torch.bfloat16)
+            logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
+            logger.info(f"Decoded input: {self.processor.decode(inputs_vlm['input_ids'][0])}")
+        except Exception as e:
+            logger.error(f"Error in tokenization: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
+
+        input_len = inputs_vlm["input_ids"].shape[-1]
+
+        with torch.inference_mode():
+            generation = await asyncio.to_thread(
+                self.model.generate,
+                **inputs_vlm,
+                max_new_tokens=max_tokens,
+                do_sample=True,
+                temperature=temperature
+            )
+            generation = generation[0][input_len:]
+
+        response = self.processor.decode(generation, skip_special_tokens=True)
+        logger.info(f"Generated response: {response}")
+        return response
+
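+    # Image.Image below is PIL's type; `from PIL import Image` is assumed to be
+    # imported elsewhere in this module, since this hunk does not add it.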
+    async def vision_query(self, image: Image.Image, query: str) -> str:
+        if not self.is_loaded:
+            await self.load()
+
+        messages_vlm = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
+            },
+            {
+                "role": "user",
+                "content": []
+            }
+        ]
+
+        messages_vlm[1]["content"].append({"type": "text", "text": query})
+        if image and image.size[0] > 0 and image.size[1] > 0:
+            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
+            logger.info(f"Received valid image for processing")
+        else:
+            logger.info("No valid image provided, processing text only")
+
+        try:
+            inputs_vlm = await asyncio.to_thread(
+                self.processor.apply_chat_template,
+                messages_vlm,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt"
+            )
+            inputs_vlm = inputs_vlm.to(self.device, dtype=torch.bfloat16)
+            logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
+        except Exception as e:
+            logger.error(f"Error in apply_chat_template: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
+
+        input_len = inputs_vlm["input_ids"].shape[-1]
+
+        with torch.inference_mode():
+            generation = await asyncio.to_thread(
+                self.model.generate,
+                **inputs_vlm,
+                max_new_tokens=512,
+                do_sample=True,
+                temperature=0.7
+            )
+            generation = generation[0][input_len:]
+
+        decoded = self.processor.decode(generation, skip_special_tokens=True)
+        logger.info(f"Vision query response: {decoded}")
+        return decoded
+
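+    # chat_v2 mirrors vision_query; only the system prompt and log label differ.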
+    async def chat_v2(self, image: Image.Image, query: str) -> str:
+        if not self.is_loaded:
+            await self.load()
+
+        messages_vlm = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
+            },
+            {
+                "role": "user",
+                "content": []
+            }
+        ]
+
+        messages_vlm[1]["content"].append({"type": "text", "text": query})
+        if image and image.size[0] > 0 and image.size[1] > 0:
+            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
+            logger.info(f"Received valid image for processing")
+        else:
+            logger.info("No valid image provided, processing text only")
+
+        try:
+            inputs_vlm = await asyncio.to_thread(
+                self.processor.apply_chat_template,
+                messages_vlm,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt"
+            )
+            inputs_vlm = inputs_vlm.to(self.device, dtype=torch.bfloat16)
+            logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
+        except Exception as e:
+            logger.error(f"Error in apply_chat_template: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
+
+        input_len = inputs_vlm["input_ids"].shape[-1]
+
+        with torch.inference_mode():
+            generation = await asyncio.to_thread(
+                self.model.generate,
+                **inputs_vlm,
+                max_new_tokens=512,
+                do_sample=True,
+                temperature=0.7
+            )
+            generation = generation[0][input_len:]
+
+        decoded = self.processor.decode(generation, skip_special_tokens=True)
+        logger.info(f"Chat_v2 response: {decoded}")
+        return decoded
+
 # TTS Manager
 class TTSManager:
     def __init__(self, device_type=device):
@@ -197,7 +399,7 @@ class ModelManager:
         if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
             model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
         elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
-            model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
+            model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
         else:
             model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
 
@@ -292,6 +494,8 @@ def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
     return model_manager.get_model(src_lang, tgt_lang)
 
 # Lifespan Event Handler
+translation_configs = []
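+# Populated in __main__ below from the selected config; lifespan() schedules a
+# load_model task for each entry once the event loop is running.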
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     async def load_all_models():
@@ -303,6 +507,12 @@ async def lifespan(app: FastAPI):
             asyncio.create_task(model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng')),
             asyncio.create_task(model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic')),
         ]
+        for config in translation_configs:
+            src_lang = config["src_lang"]
+            tgt_lang = config["tgt_lang"]
+            key = model_manager._get_model_key(src_lang, tgt_lang)
+            tasks.append(asyncio.create_task(model_manager.load_model(src_lang, tgt_lang, key)))
+
         await asyncio.gather(*tasks)
         logger.info("All models loaded successfully")
 
@@ -616,6 +826,7 @@ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(.
     if not asr_manager.model:
         raise HTTPException(status_code=503, detail="ASR model still loading, please try again later")
     try:
+        import torchaudio  # Added here for clarity
         wav, sr = torchaudio.load(file.file)
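+        # torch.mean over the channel dim folds stereo to mono; 16 kHz is the
+        # sample rate the ASR model expects.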
         wav = torch.mean(wav, dim=0, keepdim=True)
         target_sample_rate = 16000
@@ -688,10 +899,7 @@ if __name__ == "__main__":
     asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
 
     if selected_config["components"]["Translation"]:
-        for translation_config in selected_config["components"]["Translation"]:
-            src_lang = translation_config["src_lang"]
-            tgt_lang = translation_config["tgt_lang"]
-            asyncio.create_task(model_manager.load_model(src_lang, tgt_lang, model_manager._get_model_key(src_lang, tgt_lang)))
+        translation_configs.extend(selected_config["components"]["Translation"])
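+        # asyncio.create_task() here would fail because no event loop is running
+        # yet in __main__; recording the configs and loading them in lifespan()
+        # defers the work to startup, when a loop exists.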
 
     host = args.host if args.host != settings.host else settings.host
     port = args.port if args.port != settings.port else settings.port
 