sachin committed on
Commit
4d3fcd9
·
1 Parent(s): 460983d

fix-changes

Browse files
Files changed (4) hide show
  1. Dockerfile +0 -1
  2. Dockerfile.base +0 -1
  3. requirements.txt +1 -1
  4. src/server/main.py +83 -100
Dockerfile CHANGED
@@ -6,6 +6,5 @@ COPY . .
6
  ENV HF_HOME=/data/huggingface
7
  # Expose port
8
  EXPOSE 7860
9
- RUN pip install torchvision
10
  # Start the server
11
  CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860", "--config", "config_two"]
 
6
  ENV HF_HOME=/data/huggingface
7
  # Expose port
8
  EXPOSE 7860
 
9
  # Start the server
10
  CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860", "--config", "config_two"]
Dockerfile.base CHANGED
@@ -29,7 +29,6 @@ COPY requirements.txt .
29
 
30
  # Install Python dependencies
31
  RUN pip install --no-cache-dir -r requirements.txt
32
-
33
  # Set up user
34
  RUN useradd -ms /bin/bash appuser \
35
  && chown -R appuser:appuser /app
 
29
 
30
  # Install Python dependencies
31
  RUN pip install --no-cache-dir -r requirements.txt
 
32
  # Set up user
33
  RUN useradd -ms /bin/bash appuser \
34
  && chown -R appuser:appuser /app
requirements.txt CHANGED
@@ -176,7 +176,7 @@ torch==2.6.0
176
  torchaudio==2.6.0
177
  torchdiffeq==0.2.5
178
  tqdm==4.67.1
179
- transformers
180
  transformers-stream-generator==0.0.5
181
  triton==3.2.0
182
  typer==0.15.2
 
176
  torchaudio==2.6.0
177
  torchdiffeq==0.2.5
178
  tqdm==4.67.1
179
+ transformers==4.50.3
180
  transformers-stream-generator==0.0.5
181
  triton==3.2.0
182
  typer==0.15.2
src/server/main.py CHANGED
@@ -387,68 +387,61 @@ SUPPORTED_LANGUAGES = {
387
  class TranslateManager:
388
  def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
389
  self.device_type = device_type
390
- self.tokenizer = None
391
- self.model = None
392
- self.src_lang = src_lang
393
- self.tgt_lang = tgt_lang
394
- self.use_distilled = use_distilled
395
 
396
- def load(self):
397
- if not self.tokenizer or not self.model:
398
- if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
399
- model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
400
- elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
401
- model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
402
- elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
403
- model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
404
- else:
405
- raise ValueError("Invalid language combination")
406
 
407
- self.tokenizer = AutoTokenizer.from_pretrained(
408
- model_name,
409
- trust_remote_code=True
410
- )
411
- self.model = AutoModelForSeq2SeqLM.from_pretrained(
412
- model_name,
413
- trust_remote_code=True,
414
- torch_dtype=torch.float16,
415
- attn_implementation="flash_attention_2"
416
- )
417
- self.model = self.model.to(self.device_type)
418
- self.model = torch.compile(self.model, mode="reduce-overhead")
419
- logger.info(f"Translation model {model_name} loaded")
420
 
421
  class ModelManager:
422
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
423
- self.models = {}
424
  self.device_type = device_type
425
  self.use_distilled = use_distilled
426
  self.is_lazy_loading = is_lazy_loading
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
 
428
- def load_model(self, src_lang, tgt_lang, key):
429
- logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
430
- translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
431
- translate_manager.load()
432
- self.models[key] = translate_manager
433
- logger.info(f"Loaded translation model for {key}")
434
-
435
- def get_model(self, src_lang, tgt_lang):
436
- key = self._get_model_key(src_lang, tgt_lang)
437
- if key not in self.models:
438
- if self.is_lazy_loading:
439
- self.load_model(src_lang, tgt_lang, key)
440
- else:
441
- raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
442
- return self.models.get(key)
443
-
444
- def _get_model_key(self, src_lang, tgt_lang):
445
  if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
446
- return 'eng_indic'
447
  elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
448
- return 'indic_eng'
449
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
450
- return 'indic_indic'
451
- raise ValueError("Invalid language combination")
 
 
 
 
 
452
 
453
  # ASR Manager
454
  class ASRModelManager:
@@ -510,6 +503,41 @@ class TranslationResponse(BaseModel):
510
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
511
  return model_manager.get_model(src_lang, tgt_lang)
512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  # Lifespan Event Handler
514
  translation_configs = []
515
 
@@ -532,23 +560,8 @@ async def lifespan(app: FastAPI):
532
  asr_manager.load()
533
  logger.info("ASR model loaded successfully")
534
 
535
- # Load translation models
536
- translation_tasks = [
537
- ('eng_Latn', 'kan_Knda', 'eng_indic'),
538
- ('kan_Knda', 'eng_Latn', 'indic_eng'),
539
- ('kan_Knda', 'hin_Deva', 'indic_indic'),
540
- ]
541
-
542
- for config in translation_configs:
543
- src_lang = config["src_lang"]
544
- tgt_lang = config["tgt_lang"]
545
- key = model_manager._get_model_key(src_lang, tgt_lang)
546
- translation_tasks.append((src_lang, tgt_lang, key))
547
-
548
- for src_lang, tgt_lang, key in translation_tasks:
549
- logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
550
- model_manager.load_model(src_lang, tgt_lang, key)
551
- logger.info(f"Translation model for {key} loaded successfully")
552
 
553
  logger.info("All models loaded successfully")
554
  except Exception as e:
@@ -625,7 +638,6 @@ async def chat(request: Request, chat_request: ChatRequest):
625
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
626
  logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
627
  try:
628
- # Step 1: Translate prompt to English if needed
629
  if chat_request.src_lang != "eng_Latn":
630
  translated_prompt = await perform_internal_translation(
631
  sentences=[chat_request.prompt],
@@ -638,11 +650,9 @@ async def chat(request: Request, chat_request: ChatRequest):
638
  prompt_to_process = chat_request.prompt
639
  logger.info("Prompt already in English, no translation needed")
640
 
641
- # Step 2: Generate response in English
642
  response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
643
  logger.info(f"Generated English response: {response}")
644
 
645
- # Step 3: Translate response to target language if needed
646
  if chat_request.tgt_lang != "eng_Latn":
647
  translated_response = await perform_internal_translation(
648
  sentences=[response],
@@ -672,7 +682,6 @@ async def visual_query(
672
  if image.size == (0, 0):
673
  raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
674
 
675
- # Step 1: Translate query to English if needed
676
  if src_lang != "eng_Latn":
677
  translated_query = await perform_internal_translation(
678
  sentences=[query],
@@ -685,11 +694,9 @@ async def visual_query(
685
  query_to_process = query
686
  logger.info("Query already in English, no translation needed")
687
 
688
- # Step 2: Generate answer in English
689
  answer = await llm_manager.vision_query(image, query_to_process)
690
  logger.info(f"Generated English answer: {answer}")
691
 
692
- # Step 3: Translate answer to target language if needed
693
  if tgt_lang != "eng_Latn":
694
  translated_answer = await perform_internal_translation(
695
  sentences=[answer],
@@ -724,7 +731,6 @@ async def chat_v2(
724
  logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
725
 
726
  try:
727
- # Step 1: Handle image if provided
728
  img = None
729
  if image:
730
  image_data = await image.read()
@@ -732,7 +738,6 @@ async def chat_v2(
732
  raise HTTPException(status_code=400, detail="Uploaded image is empty")
733
  img = Image.open(io.BytesIO(image_data))
734
 
735
- # Step 2: Translate prompt to English if needed
736
  if src_lang != "eng_Latn":
737
  translated_prompt = await perform_internal_translation(
738
  sentences=[prompt],
@@ -745,14 +750,12 @@ async def chat_v2(
745
  prompt_to_process = prompt
746
  logger.info("Prompt already in English, no translation needed")
747
 
748
- # Step 3: Generate response in English
749
  if img:
750
  response = await llm_manager.chat_v2(img, prompt_to_process)
751
  else:
752
  response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
753
  logger.info(f"Generated English response: {response}")
754
 
755
- # Step 4: Translate response to target language if needed
756
  if tgt_lang != "eng_Latn":
757
  translated_response = await perform_internal_translation(
758
  sentences=[response],
@@ -797,14 +800,10 @@ async def synthesize_kannada(request: KannadaSynthesizeRequest):
797
 
798
  @app.post("/translate", response_model=TranslationResponse)
799
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
800
- input_sentences = request.sentences
801
- src_lang = request.src_lang
802
- tgt_lang = request.tgt_lang
803
-
804
- if not input_sentences:
805
  raise HTTPException(status_code=400, detail="Input sentences are required")
806
 
807
- batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
808
  inputs = translate_manager.tokenizer(
809
  batch,
810
  truncation=True,
@@ -830,25 +829,9 @@ async def translate(request: TranslationRequest, translate_manager: TranslateMan
830
  clean_up_tokenization_spaces=True,
831
  )
832
 
833
- translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
834
  return TranslationResponse(translations=translations)
835
 
836
- async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
837
- try:
838
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
839
- except ValueError as e:
840
- logger.info(f"Model not preloaded: {str(e)}, loading now...")
841
- key = model_manager._get_model_key(src_lang, tgt_lang)
842
- model_manager.load_model(src_lang, tgt_lang, key)
843
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
844
-
845
- if not translate_manager.model:
846
- translate_manager.load()
847
-
848
- request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
849
- response = await translate(request, translate_manager)
850
- return response.translations
851
-
852
  @app.get("/v1/health")
853
  async def health_check():
854
  return {"status": "healthy", "model": settings.llm_model_name}
 
387
  class TranslateManager:
388
  def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
389
  self.device_type = device_type
390
+ self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)
 
 
 
 
391
 
392
+ def initialize_model(self, src_lang, tgt_lang, use_distilled):
393
+ if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
394
+ model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if use_distilled else "ai4bharat/indictrans2-en-indic-1B"
395
+ elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
396
+ model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if use_distilled else "ai4bharat/indictrans2-indic-en-1B"
397
+ elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
398
+ model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
399
+ else:
400
+ raise ValueError("Invalid language combination: English to English translation is not supported.")
 
401
 
402
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
403
+ model = AutoModelForSeq2SeqLM.from_pretrained(
404
+ model_name,
405
+ trust_remote_code=True,
406
+ torch_dtype=torch.float16,
407
+ attn_implementation="flash_attention_2"
408
+ ).to(self.device_type)
409
+ return tokenizer, model
 
 
 
 
 
410
 
411
  class ModelManager:
412
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
413
+ self.models: dict[str, TranslateManager] = {}
414
  self.device_type = device_type
415
  self.use_distilled = use_distilled
416
  self.is_lazy_loading = is_lazy_loading
417
+ # Preload all translation models
418
+ self.preload_models()
419
+
420
+ def preload_models(self):
421
+ # Define the core translation pairs to preload
422
+ translation_pairs = [
423
+ ('eng_Latn', 'kan_Knda', 'eng_indic'), # English to Indic
424
+ ('kan_Knda', 'eng_Latn', 'indic_eng'), # Indic to English
425
+ ('kan_Knda', 'hin_Deva', 'indic_indic') # Indic to Indic
426
+ ]
427
+ for src_lang, tgt_lang, key in translation_pairs:
428
+ logger.info(f"Preloading translation model for {src_lang} -> {tgt_lang}...")
429
+ self.models[key] = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
430
+ logger.info(f"Translation model for {key} preloaded successfully")
431
 
432
+ def get_model(self, src_lang, tgt_lang) -> TranslateManager:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
434
+ key = 'eng_indic'
435
  elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
436
+ key = 'indic_eng'
437
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
438
+ key = 'indic_indic'
439
+ else:
440
+ raise ValueError("Invalid language combination: English to English translation is not supported.")
441
+
442
+ if key not in self.models:
443
+ raise ValueError(f"Model for {key} is not preloaded. All models should be preloaded at startup.")
444
+ return self.models[key]
445
 
446
  # ASR Manager
447
  class ASRModelManager:
 
503
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
504
  return model_manager.get_model(src_lang, tgt_lang)
505
 
506
+ # Translation Function
507
+ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
508
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
509
+ if not sentences:
510
+ raise HTTPException(status_code=400, detail="Input sentences are required")
511
+
512
+ batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
513
+ inputs = translate_manager.tokenizer(
514
+ batch,
515
+ truncation=True,
516
+ padding="longest",
517
+ return_tensors="pt",
518
+ return_attention_mask=True,
519
+ ).to(translate_manager.device_type)
520
+
521
+ with torch.no_grad():
522
+ generated_tokens = translate_manager.model.generate(
523
+ **inputs,
524
+ use_cache=True,
525
+ min_length=0,
526
+ max_length=256,
527
+ num_beams=5,
528
+ num_return_sequences=1,
529
+ )
530
+
531
+ with translate_manager.tokenizer.as_target_tokenizer():
532
+ generated_tokens = translate_manager.tokenizer.batch_decode(
533
+ generated_tokens.detach().cpu().tolist(),
534
+ skip_special_tokens=True,
535
+ clean_up_tokenization_spaces=True,
536
+ )
537
+
538
+ translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
539
+ return translations
540
+
541
  # Lifespan Event Handler
542
  translation_configs = []
543
 
 
560
  asr_manager.load()
561
  logger.info("ASR model loaded successfully")
562
 
563
+ # Translation models are preloaded in ModelManager constructor
564
+ logger.info("Translation models already preloaded in ModelManager initialization.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
 
566
  logger.info("All models loaded successfully")
567
  except Exception as e:
 
638
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
639
  logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
640
  try:
 
641
  if chat_request.src_lang != "eng_Latn":
642
  translated_prompt = await perform_internal_translation(
643
  sentences=[chat_request.prompt],
 
650
  prompt_to_process = chat_request.prompt
651
  logger.info("Prompt already in English, no translation needed")
652
 
 
653
  response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
654
  logger.info(f"Generated English response: {response}")
655
 
 
656
  if chat_request.tgt_lang != "eng_Latn":
657
  translated_response = await perform_internal_translation(
658
  sentences=[response],
 
682
  if image.size == (0, 0):
683
  raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
684
 
 
685
  if src_lang != "eng_Latn":
686
  translated_query = await perform_internal_translation(
687
  sentences=[query],
 
694
  query_to_process = query
695
  logger.info("Query already in English, no translation needed")
696
 
 
697
  answer = await llm_manager.vision_query(image, query_to_process)
698
  logger.info(f"Generated English answer: {answer}")
699
 
 
700
  if tgt_lang != "eng_Latn":
701
  translated_answer = await perform_internal_translation(
702
  sentences=[answer],
 
731
  logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
732
 
733
  try:
 
734
  img = None
735
  if image:
736
  image_data = await image.read()
 
738
  raise HTTPException(status_code=400, detail="Uploaded image is empty")
739
  img = Image.open(io.BytesIO(image_data))
740
 
 
741
  if src_lang != "eng_Latn":
742
  translated_prompt = await perform_internal_translation(
743
  sentences=[prompt],
 
750
  prompt_to_process = prompt
751
  logger.info("Prompt already in English, no translation needed")
752
 
 
753
  if img:
754
  response = await llm_manager.chat_v2(img, prompt_to_process)
755
  else:
756
  response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
757
  logger.info(f"Generated English response: {response}")
758
 
 
759
  if tgt_lang != "eng_Latn":
760
  translated_response = await perform_internal_translation(
761
  sentences=[response],
 
800
 
801
  @app.post("/translate", response_model=TranslationResponse)
802
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
803
+ if not request.sentences:
 
 
 
 
804
  raise HTTPException(status_code=400, detail="Input sentences are required")
805
 
806
+ batch = ip.preprocess_batch(request.sentences, src_lang=request.src_lang, tgt_lang=request.tgt_lang)
807
  inputs = translate_manager.tokenizer(
808
  batch,
809
  truncation=True,
 
829
  clean_up_tokenization_spaces=True,
830
  )
831
 
832
+ translations = ip.postprocess_batch(generated_tokens, lang=request.tgt_lang)
833
  return TranslationResponse(translations=translations)
834
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
835
  @app.get("/v1/health")
836
  async def health_check():
837
  return {"status": "healthy", "model": settings.llm_model_name}