sachin committed
Commit 2472b8d · 1 Parent(s): 773ab72
Files changed (1):
  1. src/server/main.py +66 -85
src/server/main.py CHANGED
@@ -14,7 +14,7 @@ from pydantic_settings import BaseSettings
 from slowapi import Limiter
 from slowapi.util import get_remote_address
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModel, AutoProcessor, BitsAndBytesConfig
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, Gemma3ForConditionalGeneration
 from IndicTransToolkit import IndicProcessor
 import json
 import asyncio
@@ -68,7 +68,7 @@ class Settings(BaseSettings):
 
 settings = Settings()
 
-# Quantization config for LLM
+# Quantization config for LLM (unchanged from gemma_llm.py)
 quantization_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
@@ -76,7 +76,7 @@ quantization_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16
 )
 
-# LLM Manager (adapted from gemma_llm.py)
+# LLM Manager (from gemma_llm.py with async load)
 class LLMManager:
     def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
         self.model_name = model_name
@@ -87,24 +87,11 @@ class LLMManager:
         self.processor = None
         logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
 
-    async def unload(self):
-        if self.is_loaded:
-            await asyncio.to_thread(self._unload_sync)
-            self.is_loaded = False
-            logger.info(f"LLM {self.model_name} unloaded from {self.device}")
-
-    def _unload_sync(self):
-        del self.model
-        del self.processor
-        if self.device.type == "cuda":
-            torch.cuda.empty_cache()
-            logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
-
     async def load(self):
         if not self.is_loaded:
             try:
                 self.model = await asyncio.to_thread(
-                    AutoModel.from_pretrained,
+                    Gemma3ForConditionalGeneration.from_pretrained,
                     self.model_name,
                     device_map="auto",
                     quantization_config=quantization_config,
@@ -118,6 +105,16 @@ class LLMManager:
                 logger.error(f"Failed to load model: {str(e)}")
                 raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
 
+    def unload(self):
+        if self.is_loaded:
+            del self.model
+            del self.processor
+            if self.device.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
+            self.is_loaded = False
+            logger.info(f"LLM {self.model_name} unloaded from {self.device}")
+
     async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
         if not self.is_loaded:
             await self.load()
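The pair of hunks above restructures the LLM lifecycle: load() stays async and runs the heavy from_pretrained call in a worker thread, while unload() becomes a plain synchronous method. A minimal sketch of how the rest of the file drives this lifecycle (the wrapper functions are illustrative; the manager and method names are the ones used here):

    # Sketch: lazy LLM lifecycle as wired up elsewhere in main.py.
    async def warm_up():
        await llm_manager.load()                  # 4-bit weights load off the event loop
        reply = await llm_manager.generate("hi")  # generate() lazy-loads if needed

    def shutdown():
        llm_manager.unload()                      # drops weights, empties the CUDA cache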
@@ -134,15 +131,13 @@ class LLMManager:
         ]
 
         try:
-            inputs_vlm = await asyncio.to_thread(
-                self.processor.apply_chat_template,
+            inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
                 return_tensors="pt"
-            )
-            inputs_vlm = inputs_vlm.to(self.device, dtype=torch.bfloat16)
+            ).to(self.device, dtype=torch.bfloat16)
             logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
             logger.info(f"Decoded input: {self.processor.decode(inputs_vlm['input_ids'][0])}")
         except Exception as e:
@@ -152,8 +147,7 @@ class LLMManager:
         input_len = inputs_vlm["input_ids"].shape[-1]
 
         with torch.inference_mode():
-            generation = await asyncio.to_thread(
-                self.model.generate,
+            generation = self.model.generate(
                 **inputs_vlm,
                 max_new_tokens=max_tokens,
                 do_sample=True,
@@ -188,15 +182,13 @@ class LLMManager:
             logger.info("No valid image provided, processing text only")
 
         try:
-            inputs_vlm = await asyncio.to_thread(
-                self.processor.apply_chat_template,
+            inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
                 return_tensors="pt"
-            )
-            inputs_vlm = inputs_vlm.to(self.device, dtype=torch.bfloat16)
+            ).to(self.device, dtype=torch.bfloat16)
             logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
@@ -205,8 +197,7 @@ class LLMManager:
         input_len = inputs_vlm["input_ids"].shape[-1]
 
         with torch.inference_mode():
-            generation = await asyncio.to_thread(
-                self.model.generate,
+            generation = self.model.generate(
                 **inputs_vlm,
                 max_new_tokens=512,
                 do_sample=True,
@@ -241,15 +232,13 @@ class LLMManager:
             logger.info("No valid image provided, processing text only")
 
         try:
-            inputs_vlm = await asyncio.to_thread(
-                self.processor.apply_chat_template,
+            inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
                 return_tensors="pt"
-            )
-            inputs_vlm = inputs_vlm.to(self.device, dtype=torch.bfloat16)
+            ).to(self.device, dtype=torch.bfloat16)
             logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
@@ -258,8 +247,7 @@ class LLMManager:
         input_len = inputs_vlm["input_ids"].shape[-1]
 
         with torch.inference_mode():
-            generation = await asyncio.to_thread(
-                self.model.generate,
+            generation = self.model.generate(
                 **inputs_vlm,
                 max_new_tokens=512,
                 do_sample=True,
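All three call sites above pass the same chat-style payload to apply_chat_template. For reference, a sketch of the message layout these calls assume (the texts are illustrative, the variable name is the one used in this file):

    # Sketch: structure of messages_vlm consumed by processor.apply_chat_template.
    messages_vlm = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]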
@@ -271,7 +259,7 @@ class LLMManager:
         logger.info(f"Chat_v2 response: {decoded}")
         return decoded
 
-# TTS Manager
+# TTS Manager (async load)
 class TTSManager:
     def __init__(self, device_type=device):
         self.device_type = device_type
@@ -360,32 +348,38 @@ SUPPORTED_LANGUAGES = {
     "por_Latn", "rus_Cyrl", "pol_Latn"
 }
 
-# Translation Manager
+# Translation Manager (async load)
 class TranslateManager:
     def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
         self.device_type = device_type
-        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)
+        self.tokenizer = None
+        self.model = None
+        self.src_lang = src_lang
+        self.tgt_lang = tgt_lang
+        self.use_distilled = use_distilled
 
-    def initialize_model(self, src_lang, tgt_lang, use_distilled):
-        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
-            model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if use_distilled else "ai4bharat/indictrans2-en-indic-1B"
-        elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
-            model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if use_distilled else "ai4bharat/indictrans2-indic-en-1B"
-        elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
-            model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
-        else:
-            raise ValueError("Invalid language combination: English to English translation is not supported.")
-
-        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-        model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_name,
-            trust_remote_code=True,
-            torch_dtype=torch.float16,
-            attn_implementation="flash_attention_2"
-        ).to(self.device_type)
-        model = torch.compile(model, mode="reduce-overhead")
-        print("Model compiled with torch.compile")
-        return tokenizer, model
+    async def load(self):
+        if not self.tokenizer or not self.model:
+            if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
+                model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
+            elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
+                model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
+            elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
+                model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
+            else:
+                raise ValueError("Invalid language combination: English to English translation is not supported.")
+
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+            self.model = await asyncio.to_thread(
+                AutoModelForSeq2SeqLM.from_pretrained,
+                model_name,
+                trust_remote_code=True,
+                torch_dtype=torch.float16,
+                attn_implementation="flash_attention_2"
+            )
+            self.model = self.model.to(self.device_type)
+            self.model = torch.compile(self.model, mode="reduce-overhead")
+            logger.info(f"Translation model {model_name} loaded for {self.src_lang} -> {self.tgt_lang}")
 
 class ModelManager:
     def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
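With initialize_model folded into an async load(), constructing a TranslateManager is now cheap and nothing touches the GPU until the manager is awaited. A minimal usage sketch, with tags in the FLORES style of the SUPPORTED_LANGUAGES set above:

    # Sketch: construction is instant; load() resolves the en->indic distilled model
    # (ai4bharat/indictrans2-en-indic-dist-200M) and torch.compiles it once.
    manager = TranslateManager("eng_Latn", "hin_Deva")
    await manager.load()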
@@ -396,24 +390,9 @@ class ModelManager:
 
     async def load_model(self, src_lang, tgt_lang, key):
         logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
-        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
-            model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
-        elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
-            model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if use_distilled else "ai4bharat/indictrans2-indic-en-1B"
-        else:
-            model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
-
-        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-        model = await asyncio.to_thread(
-            AutoModelForSeq2SeqLM.from_pretrained,
-            model_name,
-            trust_remote_code=True,
-            torch_dtype=torch.float16,
-            attn_implementation="flash_attention_2"
-        )
-        model = model.to(self.device_type)
-        model = torch.compile(model, mode="reduce-overhead")
-        self.models[key] = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
+        translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
+        await translate_manager.load()
+        self.models[key] = translate_manager
         logger.info(f"Loaded translation model for {key}")
 
     def get_model(self, src_lang, tgt_lang):
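load_model now delegates to TranslateManager.load(), so the model-name branching lives in one place; the removed copy also carried a latent bug (bare use_distilled instead of self.use_distilled in the indic-to-English branch) that the delegation retires. A sketch of the resulting call flow, with the cache key name assumed by analogy to the 'indic_indic' label visible below:

    # Sketch: cache a manager under a direction key, then fetch it per request.
    await model_manager.load_model("eng_Latn", "kan_Knda", "eng_indic")  # key name assumed
    manager = model_manager.get_model("eng_Latn", "kan_Knda")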
@@ -434,7 +413,7 @@ class ModelManager:
             return 'indic_indic'
         raise ValueError("Invalid language combination")
 
-# ASR Manager
+# ASR Manager (async load)
 class ASRModelManager:
     def __init__(self, device_type="cuda"):
         self.device_type = device_type
@@ -483,12 +462,12 @@ class TranslationRequest(BaseModel):
     src_lang: str
     tgt_lang: str
 
-class TranslationResponse(BaseModel):
-    translations: List[str]
-
 class TranscriptionResponse(BaseModel):
     text: str
 
+class TranslationResponse(BaseModel):
+    translations: List[str]
+
 # Dependency
 def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
     return model_manager.get_model(src_lang, tgt_lang)
@@ -519,7 +498,7 @@ async def lifespan(app: FastAPI):
     logger.info("Starting model loading in background...")
     asyncio.create_task(load_all_models())
     yield
-    await llm_manager.unload()
+    llm_manager.unload()  # Synchronous unload as per original gemma_llm.py
     logger.info("Server shutdown complete")
 
 # FastAPI App
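For context, the lifespan hook now pairs the backgrounded async load with this synchronous unload; a condensed sketch of the wiring the hunk sits in:

    # Sketch: FastAPI lifespan pairing background load with sync unload.
    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        asyncio.create_task(load_all_models())  # don't block startup on model loads
        yield
        llm_manager.unload()                    # synchronous cleanup on shutdown

    app = FastAPI(lifespan=lifespan)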
@@ -604,6 +583,8 @@ async def translate(request: TranslationRequest, translate_manager: TranslateMan
 
 async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
     translate_manager = model_manager.get_model(src_lang, tgt_lang)
+    if not translate_manager.model:  # Ensure model is loaded
+        await translate_manager.load()
     request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
     response = await translate(request, translate_manager)
     return response.translations
@@ -620,7 +601,7 @@ async def home():
 async def unload_all_models():
     try:
         logger.info("Starting to unload all models...")
-        await llm_manager.unload()
+        llm_manager.unload()  # Synchronous as per original
         logger.info("All models unloaded successfully")
         return {"status": "success", "message": "All models unloaded"}
     except Exception as e:
@@ -631,12 +612,12 @@ async def unload_all_models():
 async def load_all_models():
     try:
         logger.info("Starting to load all models...")
-        await llm_manager.load()
+        await llm_manager.load()  # Async load
        logger.info("All models loaded successfully")
         return {"status": "success", "message": "All models loaded"}
     except Exception as e:
         logger.error(f"Error loading models: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")
 
 @app.post("/v1/translate", response_model=TranslationResponse)
 async def translate_endpoint(request: TranslationRequest):
@@ -826,7 +807,7 @@ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(.
     if not asr_manager.model:
         raise HTTPException(status_code=503, detail="ASR model still loading, please try again later")
     try:
-        import torchaudio  # Added here for clarity
+        import torchaudio
         wav, sr = torchaudio.load(file.file)
         wav = torch.mean(wav, dim=0, keepdim=True)
         target_sample_rate = 16000
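The hunk cuts off just before the resampling step; a sketch of the standard torchaudio continuation these lines set up (the elided code is assumed, not shown in this commit):

    # Sketch: downmix is done above; resample to target_sample_rate if needed.
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
        wav = resampler(wav)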
 