sachin committed · Commit a9b565f · 1 Parent(s): ba89109
Commit message: test

src/server/main.py CHANGED (+66 -52)
@@ -14,7 +14,7 @@ from pydantic_settings import BaseSettings
 from slowapi import Limiter
 from slowapi.util import get_remote_address
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig,
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, AutoModel, Gemma3ForConditionalGeneration
 from IndicTransToolkit import IndicProcessor
 import json
 import asyncio
@@ -84,8 +84,8 @@ class LLMManager:
         self.device = torch.device(device)
         self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
         self.model = None
-        self.is_loaded = False
         self.processor = None
+        self.is_loaded = False
         logger.info(f"LLMManager initialized with model {model_name} on {self.device}")

     async def load(self):
@@ -99,12 +99,15 @@ class LLMManager:
                 torch_dtype=self.torch_dtype
             )
             self.model.eval()
-            self.processor = await asyncio.to_thread(
+            self.processor = await asyncio.to_thread(
+                AutoProcessor.from_pretrained,
+                self.model_name
+            )
             self.is_loaded = True
-            logger.info(f"LLM {self.model_name} loaded on {self.device}
+            logger.info(f"LLM {self.model_name} loaded asynchronously on {self.device}")
         except Exception as e:
-            logger.error(f"Failed to load
-            raise
+            logger.error(f"Failed to load LLM: {str(e)}")
+            raise

     def unload(self):
         if self.is_loaded:
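The hunks above switch LLMManager to the loading pattern used throughout this commit: the blocking Hugging Face from_pretrained call runs in a worker thread via asyncio.to_thread, so the event loop is not stalled while weights download and deserialize. A minimal, self-contained sketch of that pattern (the checkpoint id below is only a public placeholder for illustration, not the value of settings.llm_model_name):

import asyncio
from transformers import AutoProcessor

async def load_processor(model_id: str):
    # from_pretrained blocks on network and disk I/O; asyncio.to_thread runs it
    # in a worker thread and lets the event loop keep serving requests.
    return await asyncio.to_thread(AutoProcessor.from_pretrained, model_id)

if __name__ == "__main__":
    # Placeholder: any public checkpoint that ships a processor works here.
    processor = asyncio.run(load_processor("openai/clip-vit-base-patch32"))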
@@ -139,8 +142,6 @@ class LLMManager:
                 return_dict=True,
                 return_tensors="pt"
             ).to(self.device, dtype=torch.bfloat16)
-            logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
-            logger.info(f"Decoded input: {self.processor.decode(inputs_vlm['input_ids'][0])}")
         except Exception as e:
             logger.error(f"Error in tokenization: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
@@ -190,7 +191,6 @@ class LLMManager:
                 return_dict=True,
                 return_tensors="pt"
             ).to(self.device, dtype=torch.bfloat16)
-            logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
@@ -240,7 +240,6 @@ class LLMManager:
                 return_dict=True,
                 return_tensors="pt"
             ).to(self.device, dtype=torch.bfloat16)
-            logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
@@ -268,14 +267,15 @@ class TTSManager:
         self.repo_id = "ai4bharat/IndicF5"

     async def load(self):
-
-
-
-
-
-
-
-
+        if not self.model:
+            logger.info("Loading TTS model IndicF5 asynchronously...")
+            self.model = await asyncio.to_thread(
+                AutoModel.from_pretrained,
+                self.repo_id,
+                trust_remote_code=True
+            )
+            self.model = self.model.to(self.device_type)
+            logger.info("TTS model IndicF5 loaded asynchronously")

     def synthesize(self, text, ref_audio_path, ref_text):
         if not self.model:
@@ -368,9 +368,13 @@ class TranslateManager:
         elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
             model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
         else:
-            raise ValueError("Invalid language combination
+            raise ValueError("Invalid language combination")

-        self.tokenizer =
+        self.tokenizer = await asyncio.to_thread(
+            AutoTokenizer.from_pretrained,
+            model_name,
+            trust_remote_code=True
+        )
         self.model = await asyncio.to_thread(
             AutoModelForSeq2SeqLM.from_pretrained,
             model_name,
@@ -380,7 +384,7 @@ class TranslateManager:
         )
         self.model = self.model.to(self.device_type)
         self.model = torch.compile(self.model, mode="reduce-overhead")
-        logger.info(f"Translation model {model_name} loaded
+        logger.info(f"Translation model {model_name} loaded asynchronously")

 class ModelManager:
     def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
@@ -390,11 +394,11 @@ class ModelManager:
         self.is_lazy_loading = is_lazy_loading

     async def load_model(self, src_lang, tgt_lang, key):
-        logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
+        logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} asynchronously")
         translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
         await translate_manager.load()
         self.models[key] = translate_manager
-        logger.info(f"Loaded translation model for {key}")
+        logger.info(f"Loaded translation model for {key} asynchronously")

     def get_model(self, src_lang, tgt_lang):
         key = self._get_model_key(src_lang, tgt_lang)
@@ -422,13 +426,15 @@ class ASRModelManager:
         self.model_language = {"kannada": "kn"}

     async def load(self):
-
-
-
-
-
-
-
+        if not self.model:
+            logger.info("Loading ASR model asynchronously...")
+            self.model = await asyncio.to_thread(
+                AutoModel.from_pretrained,
+                "ai4bharat/indic-conformer-600m-multilingual",
+                trust_remote_code=True
+            )
+            self.model = self.model.to(self.device_type)
+            logger.info("ASR model loaded asynchronously")

 # Global Managers
 llm_manager = LLMManager(settings.llm_model_name)
@@ -479,25 +485,33 @@ translation_configs = []
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     async def load_all_models():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            tasks = [
+                llm_manager.load(),
+                tts_manager.load(),
+                asr_manager.load(),
+            ]
+
+            translation_tasks = [
+                model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic'),
+                model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng'),
+                model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic'),
+            ]
+
+            for config in translation_configs:
+                src_lang = config["src_lang"]
+                tgt_lang = config["tgt_lang"]
+                key = model_manager._get_model_key(src_lang, tgt_lang)
+                translation_tasks.append(model_manager.load_model(src_lang, tgt_lang, key))
+
+            await asyncio.gather(*tasks, *translation_tasks)
+            logger.info("All models loaded successfully asynchronously")
+        except Exception as e:
+            logger.error(f"Error loading models: {str(e)}")
+            raise

-    logger.info("Starting model loading
-
+    logger.info("Starting asynchronous model loading...")
+    await load_all_models()
     yield
     llm_manager.unload()
     logger.info("Server shutdown complete")
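The lifespan hunk above starts every loader concurrently: LLM, TTS, ASR, and the three default translation directions are awaited in a single asyncio.gather call, and a failure in any of them is logged and re-raised so startup aborts instead of serving with missing models. A stripped-down sketch of that flow, with a stub coroutine standing in for the real manager load() methods (fake_loader and the sleep are illustrative only):

import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI

async def fake_loader(name: str) -> None:
    # Stand-in for llm_manager.load(), tts_manager.load(), asr_manager.load(), ...
    await asyncio.sleep(0.1)
    print(f"{name} loaded")

@asynccontextmanager
async def lifespan(app: FastAPI):
    # gather() runs the coroutines concurrently and re-raises the first failure,
    # mirroring the try/except/raise inside load_all_models() above.
    await asyncio.gather(fake_loader("llm"), fake_loader("tts"), fake_loader("asr"))
    yield  # the server handles requests here; cleanup runs after the yield

app = FastAPI(lifespan=lifespan)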
@@ -526,7 +540,7 @@ app.state.limiter = limiter
 @app.post("/audio/speech", response_class=StreamingResponse)
 async def synthesize_kannada(request: KannadaSynthesizeRequest):
     if not tts_manager.model:
-        raise HTTPException(status_code=503, detail="TTS model
+        raise HTTPException(status_code=503, detail="TTS model not loaded")
     kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
     if not request.text.strip():
         raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
@@ -591,7 +605,7 @@ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_
         await model_manager.load_model(src_lang, tgt_lang, key)
     translate_manager = model_manager.get_model(src_lang, tgt_lang)

-    if not translate_manager.model:
+    if not translate_manager.model:
         await translate_manager.load()

     request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
@@ -814,7 +828,7 @@ async def chat_v2(
 @app.post("/transcribe/", response_model=TranscriptionResponse)
 async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
     if not asr_manager.model:
-        raise HTTPException(status_code=503, detail="ASR model
+        raise HTTPException(status_code=503, detail="ASR model not loaded")
     try:
         wav, sr = torchaudio.load(file.file)
         wav = torch.mean(wav, dim=0, keepdim=True)
@@ -835,7 +849,7 @@ async def speech_to_speech(
     language: str = Query(..., enum=list(asr_manager.model_language.keys())),
 ) -> StreamingResponse:
     if not tts_manager.model:
-        raise HTTPException(status_code=503, detail="TTS model
+        raise HTTPException(status_code=503, detail="TTS model not loaded")
     transcription = await transcribe_audio(file, language)
     logger.info(f"Transcribed text: {transcription.text}")
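The endpoint hunks above standardize the guard responses: /audio/speech, /transcribe/, and the speech-to-speech route return 503 with a "... not loaded" detail until the corresponding manager has finished loading. A hypothetical client-side check (host, port, payload shape, and output filename are assumptions for illustration, not taken from this commit):

import requests

resp = requests.post("http://localhost:8000/audio/speech", json={"text": "ನಮಸ್ಕಾರ"})
if resp.status_code == 503:
    # Models are still warming up; back off and retry later.
    print("Service not ready:", resp.json()["detail"])
else:
    with open("speech_output.wav", "wb") as f:
        f.write(resp.content)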