sachin
commited on
Commit
·
cb770cb
1
Parent(s):
2472b8d
fix
Browse files- src/server/main.py +17 -9
src/server/main.py
CHANGED
@@ -68,7 +68,7 @@ class Settings(BaseSettings):
|
|
68 |
|
69 |
settings = Settings()
|
70 |
|
71 |
-
# Quantization config for LLM
|
72 |
quantization_config = BitsAndBytesConfig(
|
73 |
load_in_4bit=True,
|
74 |
bnb_4bit_quant_type="nf4",
|
@@ -76,7 +76,7 @@ quantization_config = BitsAndBytesConfig(
|
|
76 |
bnb_4bit_compute_dtype=torch.bfloat16
|
77 |
)
|
78 |
|
79 |
-
# LLM Manager
|
80 |
class LLMManager:
|
81 |
def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
|
82 |
self.model_name = model_name
|
@@ -259,7 +259,7 @@ class LLMManager:
|
|
259 |
logger.info(f"Chat_v2 response: {decoded}")
|
260 |
return decoded
|
261 |
|
262 |
-
# TTS Manager
|
263 |
class TTSManager:
|
264 |
def __init__(self, device_type=device):
|
265 |
self.device_type = device_type
|
@@ -348,7 +348,7 @@ SUPPORTED_LANGUAGES = {
|
|
348 |
"por_Latn", "rus_Cyrl", "pol_Latn"
|
349 |
}
|
350 |
|
351 |
-
# Translation Manager
|
352 |
class TranslateManager:
|
353 |
def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
|
354 |
self.device_type = device_type
|
@@ -413,7 +413,7 @@ class ModelManager:
|
|
413 |
return 'indic_indic'
|
414 |
raise ValueError("Invalid language combination")
|
415 |
|
416 |
-
# ASR Manager
|
417 |
class ASRModelManager:
|
418 |
def __init__(self, device_type="cuda"):
|
419 |
self.device_type = device_type
|
@@ -498,7 +498,7 @@ async def lifespan(app: FastAPI):
|
|
498 |
logger.info("Starting model loading in background...")
|
499 |
asyncio.create_task(load_all_models())
|
500 |
yield
|
501 |
-
llm_manager.unload()
|
502 |
logger.info("Server shutdown complete")
|
503 |
|
504 |
# FastAPI App
|
@@ -582,9 +582,17 @@ async def translate(request: TranslationRequest, translate_manager: TranslateMan
|
|
582 |
return TranslationResponse(translations=translations)
|
583 |
|
584 |
async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
|
585 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
586 |
if not translate_manager.model: # Ensure model is loaded
|
587 |
await translate_manager.load()
|
|
|
588 |
request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
|
589 |
response = await translate(request, translate_manager)
|
590 |
return response.translations
|
@@ -601,7 +609,7 @@ async def home():
|
|
601 |
async def unload_all_models():
|
602 |
try:
|
603 |
logger.info("Starting to unload all models...")
|
604 |
-
llm_manager.unload()
|
605 |
logger.info("All models unloaded successfully")
|
606 |
return {"status": "success", "message": "All models unloaded"}
|
607 |
except Exception as e:
|
@@ -612,7 +620,7 @@ async def unload_all_models():
|
|
612 |
async def load_all_models():
|
613 |
try:
|
614 |
logger.info("Starting to load all models...")
|
615 |
-
await llm_manager.load()
|
616 |
logger.info("All models loaded successfully")
|
617 |
return {"status": "success", "message": "All models loaded"}
|
618 |
except Exception as e:
|
|
|
68 |
|
69 |
settings = Settings()
|
70 |
|
71 |
+
# Quantization config for LLM
|
72 |
quantization_config = BitsAndBytesConfig(
|
73 |
load_in_4bit=True,
|
74 |
bnb_4bit_quant_type="nf4",
|
|
|
76 |
bnb_4bit_compute_dtype=torch.bfloat16
|
77 |
)
|
78 |
|
79 |
+
# LLM Manager
|
80 |
class LLMManager:
|
81 |
def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
|
82 |
self.model_name = model_name
|
|
|
259 |
logger.info(f"Chat_v2 response: {decoded}")
|
260 |
return decoded
|
261 |
|
262 |
+
# TTS Manager
|
263 |
class TTSManager:
|
264 |
def __init__(self, device_type=device):
|
265 |
self.device_type = device_type
|
|
|
348 |
"por_Latn", "rus_Cyrl", "pol_Latn"
|
349 |
}
|
350 |
|
351 |
+
# Translation Manager
|
352 |
class TranslateManager:
|
353 |
def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
|
354 |
self.device_type = device_type
|
|
|
413 |
return 'indic_indic'
|
414 |
raise ValueError("Invalid language combination")
|
415 |
|
416 |
+
# ASR Manager
|
417 |
class ASRModelManager:
|
418 |
def __init__(self, device_type="cuda"):
|
419 |
self.device_type = device_type
|
|
|
498 |
logger.info("Starting model loading in background...")
|
499 |
asyncio.create_task(load_all_models())
|
500 |
yield
|
501 |
+
llm_manager.unload()
|
502 |
logger.info("Server shutdown complete")
|
503 |
|
504 |
# FastAPI App
|
|
|
582 |
return TranslationResponse(translations=translations)
|
583 |
|
584 |
async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
|
585 |
+
try:
|
586 |
+
translate_manager = model_manager.get_model(src_lang, tgt_lang)
|
587 |
+
except ValueError as e:
|
588 |
+
logger.info(f"Model not preloaded: {str(e)}, loading now...")
|
589 |
+
key = model_manager._get_model_key(src_lang, tgt_lang)
|
590 |
+
await model_manager.load_model(src_lang, tgt_lang, key)
|
591 |
+
translate_manager = model_manager.get_model(src_lang, tgt_lang)
|
592 |
+
|
593 |
if not translate_manager.model: # Ensure model is loaded
|
594 |
await translate_manager.load()
|
595 |
+
|
596 |
request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
|
597 |
response = await translate(request, translate_manager)
|
598 |
return response.translations
|
|
|
609 |
async def unload_all_models():
|
610 |
try:
|
611 |
logger.info("Starting to unload all models...")
|
612 |
+
llm_manager.unload()
|
613 |
logger.info("All models unloaded successfully")
|
614 |
return {"status": "success", "message": "All models unloaded"}
|
615 |
except Exception as e:
|
|
|
620 |
async def load_all_models():
|
621 |
try:
|
622 |
logger.info("Starting to load all models...")
|
623 |
+
await llm_manager.load()
|
624 |
logger.info("All models loaded successfully")
|
625 |
return {"status": "success", "message": "All models loaded"}
|
626 |
except Exception as e:
|