sachin committed on
Commit cb770cb · 1 Parent(s): 2472b8d
Files changed (1)
  1. src/server/main.py +17 -9
src/server/main.py CHANGED
@@ -68,7 +68,7 @@ class Settings(BaseSettings):
 
 settings = Settings()
 
-# Quantization config for LLM (unchanged from gemma_llm.py)
+# Quantization config for LLM
 quantization_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
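
For context on the hunk above: a BitsAndBytesConfig like this is normally passed straight to the Transformers loader. A minimal sketch of that usage, with a placeholder model id (the real model name comes from Settings in main.py and is not part of this diff):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 config mirroring the fields visible in the hunk above
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Placeholder model id, used here only for illustration
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    quantization_config=quantization_config,
    device_map="auto",
)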
@@ -76,7 +76,7 @@ quantization_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16
 )
 
-# LLM Manager (from gemma_llm.py with async load)
+# LLM Manager
 class LLMManager:
     def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
         self.model_name = model_name
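
The dropped "(from gemma_llm.py with async load)" note referred to loading the model without blocking the event loop. LLMManager.load itself is outside this hunk, so the sketch below is only an assumed shape of that pattern, with illustrative names:

import asyncio

class LLMManagerSketch:
    """Illustrative only: offload a blocking model load to a worker thread."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model = None

    def _load_sync(self):
        # A blocking Transformers from_pretrained call would go here.
        return object()  # stand-in for the loaded model

    async def load(self):
        loop = asyncio.get_running_loop()
        # Keeps the FastAPI event loop responsive while weights load.
        self.model = await loop.run_in_executor(None, self._load_sync)

    def unload(self):
        # Real code would also release GPU memory here.
        self.model = None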
@@ -259,7 +259,7 @@ class LLMManager:
         logger.info(f"Chat_v2 response: {decoded}")
         return decoded
 
-# TTS Manager (async load)
+# TTS Manager
 class TTSManager:
     def __init__(self, device_type=device):
         self.device_type = device_type
@@ -348,7 +348,7 @@ SUPPORTED_LANGUAGES = {
     "por_Latn", "rus_Cyrl", "pol_Latn"
 }
 
-# Translation Manager (async load)
+# Translation Manager
 class TranslateManager:
     def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
         self.device_type = device_type
@@ -413,7 +413,7 @@ class ModelManager:
             return 'indic_indic'
         raise ValueError("Invalid language combination")
 
-# ASR Manager (async load)
+# ASR Manager
 class ASRModelManager:
     def __init__(self, device_type="cuda"):
         self.device_type = device_type
@@ -498,7 +498,7 @@ async def lifespan(app: FastAPI):
     logger.info("Starting model loading in background...")
     asyncio.create_task(load_all_models())
     yield
-    llm_manager.unload() # Synchronous unload as per original gemma_llm.py
+    llm_manager.unload()
     logger.info("Server shutdown complete")
 
 # FastAPI App
  # FastAPI App
@@ -582,9 +582,17 @@ async def translate(request: TranslationRequest, translate_manager: TranslateMan
582
  return TranslationResponse(translations=translations)
583
 
584
  async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
585
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
 
 
 
 
 
 
 
586
  if not translate_manager.model: # Ensure model is loaded
587
  await translate_manager.load()
 
588
  request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
589
  response = await translate(request, translate_manager)
590
  return response.translations
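
The substantive change is in perform_internal_translation: a ValueError from model_manager.get_model now triggers an on-demand load of the missing language pair before retrying, instead of failing outright. The real ModelManager lives outside this hunk, so the sketch below only illustrates the contract the new code assumes (get_model raises ValueError for an unloaded pair, load_model populates it); the language-prefix checks are likewise an assumption:

class ModelManagerSketch:
    """Illustrative contract only, not the actual ModelManager in main.py."""

    def __init__(self):
        self.models = {}  # key ('eng_indic', 'indic_eng', 'indic_indic') -> manager

    def _get_model_key(self, src_lang: str, tgt_lang: str) -> str:
        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            return "eng_indic"
        if not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
            return "indic_eng"
        if not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            return "indic_indic"
        raise ValueError("Invalid language combination")

    def get_model(self, src_lang: str, tgt_lang: str):
        key = self._get_model_key(src_lang, tgt_lang)
        if key not in self.models:
            # This is the error path the new try/except catches.
            raise ValueError(f"Model for {key} is not loaded")
        return self.models[key]

    async def load_model(self, src_lang: str, tgt_lang: str, key: str):
        # Real code constructs and asynchronously loads a TranslateManager here.
        self.models[key] = object()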
@@ -601,7 +609,7 @@ async def home():
 async def unload_all_models():
     try:
         logger.info("Starting to unload all models...")
-        llm_manager.unload() # Synchronous as per original
+        llm_manager.unload()
         logger.info("All models unloaded successfully")
         return {"status": "success", "message": "All models unloaded"}
     except Exception as e:
@@ -612,7 +620,7 @@ async def unload_all_models():
 async def load_all_models():
     try:
         logger.info("Starting to load all models...")
-        await llm_manager.load() # Async load
+        await llm_manager.load()
         logger.info("All models loaded successfully")
         return {"status": "success", "message": "All models loaded"}
     except Exception as e:
 