sachin
committed on
Commit
·
4d3fcd9
1
Parent(s):
460983d
fix-changes
Browse files
- Dockerfile +0 -1
- Dockerfile.base +0 -1
- requirements.txt +1 -1
- src/server/main.py +83 -100
Dockerfile
CHANGED
@@ -6,6 +6,5 @@ COPY . .
|
|
6 |
ENV HF_HOME=/data/huggingface
|
7 |
# Expose port
|
8 |
EXPOSE 7860
|
9 |
-
RUN pip install torchvision
|
10 |
# Start the server
|
11 |
CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860", "--config", "config_two"]
|
|
|
6 |
ENV HF_HOME=/data/huggingface
|
7 |
# Expose port
|
8 |
EXPOSE 7860
|
|
|
9 |
# Start the server
|
10 |
CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860", "--config", "config_two"]
|
Dockerfile.base
CHANGED
@@ -29,7 +29,6 @@ COPY requirements.txt .
|
|
29 |
|
30 |
# Install Python dependencies
|
31 |
RUN pip install --no-cache-dir -r requirements.txt
|
32 |
-
|
33 |
# Set up user
|
34 |
RUN useradd -ms /bin/bash appuser \
|
35 |
&& chown -R appuser:appuser /app
|
|
|
29 |
|
30 |
# Install Python dependencies
|
31 |
RUN pip install --no-cache-dir -r requirements.txt
|
|
|
32 |
# Set up user
|
33 |
RUN useradd -ms /bin/bash appuser \
|
34 |
&& chown -R appuser:appuser /app
|
requirements.txt
CHANGED
@@ -176,7 +176,7 @@ torch==2.6.0
|
|
176 |
torchaudio==2.6.0
|
177 |
torchdiffeq==0.2.5
|
178 |
tqdm==4.67.1
|
179 |
-
transformers
|
180 |
transformers-stream-generator==0.0.5
|
181 |
triton==3.2.0
|
182 |
typer==0.15.2
|
|
|
176 |
torchaudio==2.6.0
|
177 |
torchdiffeq==0.2.5
|
178 |
tqdm==4.67.1
|
179 |
+
transformers==4.50.3
|
180 |
transformers-stream-generator==0.0.5
|
181 |
triton==3.2.0
|
182 |
typer==0.15.2
|
src/server/main.py
CHANGED
@@ -387,68 +387,61 @@ SUPPORTED_LANGUAGES = {
|
|
387 |
class TranslateManager:
|
388 |
def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
|
389 |
self.device_type = device_type
|
390 |
-
self.tokenizer =
|
391 |
-
self.model = None
|
392 |
-
self.src_lang = src_lang
|
393 |
-
self.tgt_lang = tgt_lang
|
394 |
-
self.use_distilled = use_distilled
|
395 |
|
396 |
-
def
|
397 |
-
if
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
raise ValueError("Invalid language combination")
|
406 |
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
attn_implementation="flash_attention_2"
|
416 |
-
)
|
417 |
-
self.model = self.model.to(self.device_type)
|
418 |
-
self.model = torch.compile(self.model, mode="reduce-overhead")
|
419 |
-
logger.info(f"Translation model {model_name} loaded")
|
420 |
|
421 |
class ModelManager:
|
422 |
def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
|
423 |
-
self.models = {}
|
424 |
self.device_type = device_type
|
425 |
self.use_distilled = use_distilled
|
426 |
self.is_lazy_loading = is_lazy_loading
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
427 |
|
428 |
-
def
|
429 |
-
logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
|
430 |
-
translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
|
431 |
-
translate_manager.load()
|
432 |
-
self.models[key] = translate_manager
|
433 |
-
logger.info(f"Loaded translation model for {key}")
|
434 |
-
|
435 |
-
def get_model(self, src_lang, tgt_lang):
|
436 |
-
key = self._get_model_key(src_lang, tgt_lang)
|
437 |
-
if key not in self.models:
|
438 |
-
if self.is_lazy_loading:
|
439 |
-
self.load_model(src_lang, tgt_lang, key)
|
440 |
-
else:
|
441 |
-
raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
|
442 |
-
return self.models.get(key)
|
443 |
-
|
444 |
-
def _get_model_key(self, src_lang, tgt_lang):
|
445 |
if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
|
446 |
-
|
447 |
elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
|
448 |
-
|
449 |
elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
|
450 |
-
|
451 |
-
|
|
|
|
|
|
|
|
|
|
|
452 |
|
453 |
# ASR Manager
|
454 |
class ASRModelManager:
|
@@ -510,6 +503,41 @@ class TranslationResponse(BaseModel):
|
|
510 |
def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
|
511 |
return model_manager.get_model(src_lang, tgt_lang)
|
512 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
513 |
# Lifespan Event Handler
|
514 |
translation_configs = []
|
515 |
|
@@ -532,23 +560,8 @@ async def lifespan(app: FastAPI):
|
|
532 |
asr_manager.load()
|
533 |
logger.info("ASR model loaded successfully")
|
534 |
|
535 |
-
#
|
536 |
-
|
537 |
-
('eng_Latn', 'kan_Knda', 'eng_indic'),
|
538 |
-
('kan_Knda', 'eng_Latn', 'indic_eng'),
|
539 |
-
('kan_Knda', 'hin_Deva', 'indic_indic'),
|
540 |
-
]
|
541 |
-
|
542 |
-
for config in translation_configs:
|
543 |
-
src_lang = config["src_lang"]
|
544 |
-
tgt_lang = config["tgt_lang"]
|
545 |
-
key = model_manager._get_model_key(src_lang, tgt_lang)
|
546 |
-
translation_tasks.append((src_lang, tgt_lang, key))
|
547 |
-
|
548 |
-
for src_lang, tgt_lang, key in translation_tasks:
|
549 |
-
logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
|
550 |
-
model_manager.load_model(src_lang, tgt_lang, key)
|
551 |
-
logger.info(f"Translation model for {key} loaded successfully")
|
552 |
|
553 |
logger.info("All models loaded successfully")
|
554 |
except Exception as e:
|
@@ -625,7 +638,6 @@ async def chat(request: Request, chat_request: ChatRequest):
|
|
625 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
626 |
logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
|
627 |
try:
|
628 |
-
# Step 1: Translate prompt to English if needed
|
629 |
if chat_request.src_lang != "eng_Latn":
|
630 |
translated_prompt = await perform_internal_translation(
|
631 |
sentences=[chat_request.prompt],
|
@@ -638,11 +650,9 @@ async def chat(request: Request, chat_request: ChatRequest):
|
|
638 |
prompt_to_process = chat_request.prompt
|
639 |
logger.info("Prompt already in English, no translation needed")
|
640 |
|
641 |
-
# Step 2: Generate response in English
|
642 |
response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
|
643 |
logger.info(f"Generated English response: {response}")
|
644 |
|
645 |
-
# Step 3: Translate response to target language if needed
|
646 |
if chat_request.tgt_lang != "eng_Latn":
|
647 |
translated_response = await perform_internal_translation(
|
648 |
sentences=[response],
|
@@ -672,7 +682,6 @@ async def visual_query(
|
|
672 |
if image.size == (0, 0):
|
673 |
raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
|
674 |
|
675 |
-
# Step 1: Translate query to English if needed
|
676 |
if src_lang != "eng_Latn":
|
677 |
translated_query = await perform_internal_translation(
|
678 |
sentences=[query],
|
@@ -685,11 +694,9 @@ async def visual_query(
|
|
685 |
query_to_process = query
|
686 |
logger.info("Query already in English, no translation needed")
|
687 |
|
688 |
-
# Step 2: Generate answer in English
|
689 |
answer = await llm_manager.vision_query(image, query_to_process)
|
690 |
logger.info(f"Generated English answer: {answer}")
|
691 |
|
692 |
-
# Step 3: Translate answer to target language if needed
|
693 |
if tgt_lang != "eng_Latn":
|
694 |
translated_answer = await perform_internal_translation(
|
695 |
sentences=[answer],
|
@@ -724,7 +731,6 @@ async def chat_v2(
|
|
724 |
logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
|
725 |
|
726 |
try:
|
727 |
-
# Step 1: Handle image if provided
|
728 |
img = None
|
729 |
if image:
|
730 |
image_data = await image.read()
|
@@ -732,7 +738,6 @@ async def chat_v2(
|
|
732 |
raise HTTPException(status_code=400, detail="Uploaded image is empty")
|
733 |
img = Image.open(io.BytesIO(image_data))
|
734 |
|
735 |
-
# Step 2: Translate prompt to English if needed
|
736 |
if src_lang != "eng_Latn":
|
737 |
translated_prompt = await perform_internal_translation(
|
738 |
sentences=[prompt],
|
@@ -745,14 +750,12 @@ async def chat_v2(
|
|
745 |
prompt_to_process = prompt
|
746 |
logger.info("Prompt already in English, no translation needed")
|
747 |
|
748 |
-
# Step 3: Generate response in English
|
749 |
if img:
|
750 |
response = await llm_manager.chat_v2(img, prompt_to_process)
|
751 |
else:
|
752 |
response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
|
753 |
logger.info(f"Generated English response: {response}")
|
754 |
|
755 |
-
# Step 4: Translate response to target language if needed
|
756 |
if tgt_lang != "eng_Latn":
|
757 |
translated_response = await perform_internal_translation(
|
758 |
sentences=[response],
|
@@ -797,14 +800,10 @@ async def synthesize_kannada(request: KannadaSynthesizeRequest):
|
|
797 |
|
798 |
@app.post("/translate", response_model=TranslationResponse)
|
799 |
async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
|
800 |
-
|
801 |
-
src_lang = request.src_lang
|
802 |
-
tgt_lang = request.tgt_lang
|
803 |
-
|
804 |
-
if not input_sentences:
|
805 |
raise HTTPException(status_code=400, detail="Input sentences are required")
|
806 |
|
807 |
-
batch = ip.preprocess_batch(
|
808 |
inputs = translate_manager.tokenizer(
|
809 |
batch,
|
810 |
truncation=True,
|
@@ -830,25 +829,9 @@ async def translate(request: TranslationRequest, translate_manager: TranslateMan
|
|
830 |
clean_up_tokenization_spaces=True,
|
831 |
)
|
832 |
|
833 |
-
translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
|
834 |
return TranslationResponse(translations=translations)
|
835 |
|
836 |
-
async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
|
837 |
-
try:
|
838 |
-
translate_manager = model_manager.get_model(src_lang, tgt_lang)
|
839 |
-
except ValueError as e:
|
840 |
-
logger.info(f"Model not preloaded: {str(e)}, loading now...")
|
841 |
-
key = model_manager._get_model_key(src_lang, tgt_lang)
|
842 |
-
model_manager.load_model(src_lang, tgt_lang, key)
|
843 |
-
translate_manager = model_manager.get_model(src_lang, tgt_lang)
|
844 |
-
|
845 |
-
if not translate_manager.model:
|
846 |
-
translate_manager.load()
|
847 |
-
|
848 |
-
request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
|
849 |
-
response = await translate(request, translate_manager)
|
850 |
-
return response.translations
|
851 |
-
|
852 |
@app.get("/v1/health")
|
853 |
async def health_check():
|
854 |
return {"status": "healthy", "model": settings.llm_model_name}
|
|
|
387 |
class TranslateManager:
|
388 |
def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
|
389 |
self.device_type = device_type
|
390 |
+
self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)
|
|
|
|
|
|
|
|
|
391 |
|
392 |
+
def initialize_model(self, src_lang, tgt_lang, use_distilled):
|
393 |
+
if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
|
394 |
+
model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if use_distilled else "ai4bharat/indictrans2-en-indic-1B"
|
395 |
+
elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
|
396 |
+
model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if use_distilled else "ai4bharat/indictrans2-indic-en-1B"
|
397 |
+
elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
|
398 |
+
model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
|
399 |
+
else:
|
400 |
+
raise ValueError("Invalid language combination: English to English translation is not supported.")
|
|
|
401 |
|
402 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
403 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
404 |
+
model_name,
|
405 |
+
trust_remote_code=True,
|
406 |
+
torch_dtype=torch.float16,
|
407 |
+
attn_implementation="flash_attention_2"
|
408 |
+
).to(self.device_type)
|
409 |
+
return tokenizer, model
|
|
|
|
|
|
|
|
|
|
|
410 |
|
411 |
class ModelManager:
|
412 |
def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
|
413 |
+
self.models: dict[str, TranslateManager] = {}
|
414 |
self.device_type = device_type
|
415 |
self.use_distilled = use_distilled
|
416 |
self.is_lazy_loading = is_lazy_loading
|
417 |
+
# Preload all translation models
|
418 |
+
self.preload_models()
|
419 |
+
|
420 |
+
def preload_models(self):
|
421 |
+
# Define the core translation pairs to preload
|
422 |
+
translation_pairs = [
|
423 |
+
('eng_Latn', 'kan_Knda', 'eng_indic'), # English to Indic
|
424 |
+
('kan_Knda', 'eng_Latn', 'indic_eng'), # Indic to English
|
425 |
+
('kan_Knda', 'hin_Deva', 'indic_indic') # Indic to Indic
|
426 |
+
]
|
427 |
+
for src_lang, tgt_lang, key in translation_pairs:
|
428 |
+
logger.info(f"Preloading translation model for {src_lang} -> {tgt_lang}...")
|
429 |
+
self.models[key] = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
|
430 |
+
logger.info(f"Translation model for {key} preloaded successfully")
|
431 |
|
432 |
+
def get_model(self, src_lang, tgt_lang) -> TranslateManager:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
433 |
if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
|
434 |
+
key = 'eng_indic'
|
435 |
elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
|
436 |
+
key = 'indic_eng'
|
437 |
elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
|
438 |
+
key = 'indic_indic'
|
439 |
+
else:
|
440 |
+
raise ValueError("Invalid language combination: English to English translation is not supported.")
|
441 |
+
|
442 |
+
if key not in self.models:
|
443 |
+
raise ValueError(f"Model for {key} is not preloaded. All models should be preloaded at startup.")
|
444 |
+
return self.models[key]
|
445 |
|
446 |
# ASR Manager
|
447 |
class ASRModelManager:
|
|
|
503 |
def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
|
504 |
return model_manager.get_model(src_lang, tgt_lang)
|
505 |
|
506 |
+
# Translation Function
|
507 |
+
async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
|
508 |
+
translate_manager = model_manager.get_model(src_lang, tgt_lang)
|
509 |
+
if not sentences:
|
510 |
+
raise HTTPException(status_code=400, detail="Input sentences are required")
|
511 |
+
|
512 |
+
batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
|
513 |
+
inputs = translate_manager.tokenizer(
|
514 |
+
batch,
|
515 |
+
truncation=True,
|
516 |
+
padding="longest",
|
517 |
+
return_tensors="pt",
|
518 |
+
return_attention_mask=True,
|
519 |
+
).to(translate_manager.device_type)
|
520 |
+
|
521 |
+
with torch.no_grad():
|
522 |
+
generated_tokens = translate_manager.model.generate(
|
523 |
+
**inputs,
|
524 |
+
use_cache=True,
|
525 |
+
min_length=0,
|
526 |
+
max_length=256,
|
527 |
+
num_beams=5,
|
528 |
+
num_return_sequences=1,
|
529 |
+
)
|
530 |
+
|
531 |
+
with translate_manager.tokenizer.as_target_tokenizer():
|
532 |
+
generated_tokens = translate_manager.tokenizer.batch_decode(
|
533 |
+
generated_tokens.detach().cpu().tolist(),
|
534 |
+
skip_special_tokens=True,
|
535 |
+
clean_up_tokenization_spaces=True,
|
536 |
+
)
|
537 |
+
|
538 |
+
translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
|
539 |
+
return translations
|
540 |
+
|
541 |
# Lifespan Event Handler
|
542 |
translation_configs = []
|
543 |
|
|
|
560 |
asr_manager.load()
|
561 |
logger.info("ASR model loaded successfully")
|
562 |
|
563 |
+
# Translation models are preloaded in ModelManager constructor
|
564 |
+
logger.info("Translation models already preloaded in ModelManager initialization.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
565 |
|
566 |
logger.info("All models loaded successfully")
|
567 |
except Exception as e:
|
|
|
638 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
639 |
logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
|
640 |
try:
|
|
|
641 |
if chat_request.src_lang != "eng_Latn":
|
642 |
translated_prompt = await perform_internal_translation(
|
643 |
sentences=[chat_request.prompt],
|
|
|
650 |
prompt_to_process = chat_request.prompt
|
651 |
logger.info("Prompt already in English, no translation needed")
|
652 |
|
|
|
653 |
response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
|
654 |
logger.info(f"Generated English response: {response}")
|
655 |
|
|
|
656 |
if chat_request.tgt_lang != "eng_Latn":
|
657 |
translated_response = await perform_internal_translation(
|
658 |
sentences=[response],
|
|
|
682 |
if image.size == (0, 0):
|
683 |
raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
|
684 |
|
|
|
685 |
if src_lang != "eng_Latn":
|
686 |
translated_query = await perform_internal_translation(
|
687 |
sentences=[query],
|
|
|
694 |
query_to_process = query
|
695 |
logger.info("Query already in English, no translation needed")
|
696 |
|
|
|
697 |
answer = await llm_manager.vision_query(image, query_to_process)
|
698 |
logger.info(f"Generated English answer: {answer}")
|
699 |
|
|
|
700 |
if tgt_lang != "eng_Latn":
|
701 |
translated_answer = await perform_internal_translation(
|
702 |
sentences=[answer],
|
|
|
731 |
logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
|
732 |
|
733 |
try:
|
|
|
734 |
img = None
|
735 |
if image:
|
736 |
image_data = await image.read()
|
|
|
738 |
raise HTTPException(status_code=400, detail="Uploaded image is empty")
|
739 |
img = Image.open(io.BytesIO(image_data))
|
740 |
|
|
|
741 |
if src_lang != "eng_Latn":
|
742 |
translated_prompt = await perform_internal_translation(
|
743 |
sentences=[prompt],
|
|
|
750 |
prompt_to_process = prompt
|
751 |
logger.info("Prompt already in English, no translation needed")
|
752 |
|
|
|
753 |
if img:
|
754 |
response = await llm_manager.chat_v2(img, prompt_to_process)
|
755 |
else:
|
756 |
response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
|
757 |
logger.info(f"Generated English response: {response}")
|
758 |
|
|
|
759 |
if tgt_lang != "eng_Latn":
|
760 |
translated_response = await perform_internal_translation(
|
761 |
sentences=[response],
|
|
|
800 |
|
801 |
@app.post("/translate", response_model=TranslationResponse)
|
802 |
async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
|
803 |
+
if not request.sentences:
|
|
|
|
|
|
|
|
|
804 |
raise HTTPException(status_code=400, detail="Input sentences are required")
|
805 |
|
806 |
+
batch = ip.preprocess_batch(request.sentences, src_lang=request.src_lang, tgt_lang=request.tgt_lang)
|
807 |
inputs = translate_manager.tokenizer(
|
808 |
batch,
|
809 |
truncation=True,
|
|
|
829 |
clean_up_tokenization_spaces=True,
|
830 |
)
|
831 |
|
832 |
+
translations = ip.postprocess_batch(generated_tokens, lang=request.tgt_lang)
|
833 |
return TranslationResponse(translations=translations)
|
834 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
835 |
@app.get("/v1/health")
|
836 |
async def health_check():
|
837 |
return {"status": "healthy", "model": settings.llm_model_name}
|