sachin commited on
Commit
a9b565f
·
1 Parent(s): ba89109
Files changed (1) hide show
  1. src/server/main.py +66 -52
src/server/main.py CHANGED
@@ -14,7 +14,7 @@ from pydantic_settings import BaseSettings
14
  from slowapi import Limiter
15
  from slowapi.util import get_remote_address
16
  import torch
17
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, Gemma3ForConditionalGeneration, AutoModel
18
  from IndicTransToolkit import IndicProcessor
19
  import json
20
  import asyncio
@@ -84,8 +84,8 @@ class LLMManager:
84
  self.device = torch.device(device)
85
  self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
86
  self.model = None
87
- self.is_loaded = False
88
  self.processor = None
 
89
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
90
 
91
  async def load(self):
@@ -99,12 +99,15 @@ class LLMManager:
99
  torch_dtype=self.torch_dtype
100
  )
101
  self.model.eval()
102
- self.processor = await asyncio.to_thread(AutoProcessor.from_pretrained, self.model_name)
 
 
 
103
  self.is_loaded = True
104
- logger.info(f"LLM {self.model_name} loaded on {self.device} with 4-bit quantization")
105
  except Exception as e:
106
- logger.error(f"Failed to load model: {str(e)}")
107
- raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
108
 
109
  def unload(self):
110
  if self.is_loaded:
@@ -139,8 +142,6 @@ class LLMManager:
139
  return_dict=True,
140
  return_tensors="pt"
141
  ).to(self.device, dtype=torch.bfloat16)
142
- logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
143
- logger.info(f"Decoded input: {self.processor.decode(inputs_vlm['input_ids'][0])}")
144
  except Exception as e:
145
  logger.error(f"Error in tokenization: {str(e)}")
146
  raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
@@ -190,7 +191,6 @@ class LLMManager:
190
  return_dict=True,
191
  return_tensors="pt"
192
  ).to(self.device, dtype=torch.bfloat16)
193
- logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
194
  except Exception as e:
195
  logger.error(f"Error in apply_chat_template: {str(e)}")
196
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
@@ -240,7 +240,6 @@ class LLMManager:
240
  return_dict=True,
241
  return_tensors="pt"
242
  ).to(self.device, dtype=torch.bfloat16)
243
- logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
244
  except Exception as e:
245
  logger.error(f"Error in apply_chat_template: {str(e)}")
246
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
@@ -268,14 +267,15 @@ class TTSManager:
268
  self.repo_id = "ai4bharat/IndicF5"
269
 
270
  async def load(self):
271
- logger.info("Loading TTS model IndicF5...")
272
- self.model = await asyncio.to_thread(
273
- AutoModel.from_pretrained,
274
- self.repo_id,
275
- trust_remote_code=True
276
- )
277
- self.model = self.model.to(self.device_type)
278
- logger.info("TTS model IndicF5 loaded")
 
279
 
280
  def synthesize(self, text, ref_audio_path, ref_text):
281
  if not self.model:
@@ -368,9 +368,13 @@ class TranslateManager:
368
  elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
369
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
370
  else:
371
- raise ValueError("Invalid language combination: English to English translation is not supported.")
372
 
373
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 
 
 
374
  self.model = await asyncio.to_thread(
375
  AutoModelForSeq2SeqLM.from_pretrained,
376
  model_name,
@@ -380,7 +384,7 @@ class TranslateManager:
380
  )
381
  self.model = self.model.to(self.device_type)
382
  self.model = torch.compile(self.model, mode="reduce-overhead")
383
- logger.info(f"Translation model {model_name} loaded for {self.src_lang} -> {self.tgt_lang}")
384
 
385
  class ModelManager:
386
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
@@ -390,11 +394,11 @@ class ModelManager:
390
  self.is_lazy_loading = is_lazy_loading
391
 
392
  async def load_model(self, src_lang, tgt_lang, key):
393
- logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
394
  translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
395
  await translate_manager.load()
396
  self.models[key] = translate_manager
397
- logger.info(f"Loaded translation model for {key}")
398
 
399
  def get_model(self, src_lang, tgt_lang):
400
  key = self._get_model_key(src_lang, tgt_lang)
@@ -422,13 +426,15 @@ class ASRModelManager:
422
  self.model_language = {"kannada": "kn"}
423
 
424
  async def load(self):
425
- logger.info("Loading ASR model...")
426
- self.model = await asyncio.to_thread(
427
- AutoModel.from_pretrained,
428
- "ai4bharat/indic-conformer-600m-multilingual",
429
- trust_remote_code=True
430
- )
431
- logger.info("ASR model loaded")
 
 
432
 
433
  # Global Managers
434
  llm_manager = LLMManager(settings.llm_model_name)
@@ -479,25 +485,33 @@ translation_configs = []
479
  @asynccontextmanager
480
  async def lifespan(app: FastAPI):
481
  async def load_all_models():
482
- tasks = [
483
- asyncio.create_task(llm_manager.load()),
484
- asyncio.create_task(asr_manager.load()),
485
- asyncio.create_task(tts_manager.load()),
486
- asyncio.create_task(model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic')),
487
- asyncio.create_task(model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng')),
488
- asyncio.create_task(model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic')),
489
- ]
490
- for config in translation_configs:
491
- src_lang = config["src_lang"]
492
- tgt_lang = config["tgt_lang"]
493
- key = model_manager._get_model_key(src_lang, tgt_lang)
494
- tasks.append(asyncio.create_task(model_manager.load_model(src_lang, tgt_lang, key)))
495
-
496
- await asyncio.gather(*tasks)
497
- logger.info("All models loaded successfully")
 
 
 
 
 
 
 
 
498
 
499
- logger.info("Starting model loading in background...")
500
- asyncio.create_task(load_all_models())
501
  yield
502
  llm_manager.unload()
503
  logger.info("Server shutdown complete")
@@ -526,7 +540,7 @@ app.state.limiter = limiter
526
  @app.post("/audio/speech", response_class=StreamingResponse)
527
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
528
  if not tts_manager.model:
529
- raise HTTPException(status_code=503, detail="TTS model still loading, please try again later")
530
  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
531
  if not request.text.strip():
532
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
@@ -591,7 +605,7 @@ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_
591
  await model_manager.load_model(src_lang, tgt_lang, key)
592
  translate_manager = model_manager.get_model(src_lang, tgt_lang)
593
 
594
- if not translate_manager.model: # Ensure model is loaded
595
  await translate_manager.load()
596
 
597
  request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
@@ -814,7 +828,7 @@ async def chat_v2(
814
  @app.post("/transcribe/", response_model=TranscriptionResponse)
815
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
816
  if not asr_manager.model:
817
- raise HTTPException(status_code=503, detail="ASR model still loading, please try again later")
818
  try:
819
  wav, sr = torchaudio.load(file.file)
820
  wav = torch.mean(wav, dim=0, keepdim=True)
@@ -835,7 +849,7 @@ async def speech_to_speech(
835
  language: str = Query(..., enum=list(asr_manager.model_language.keys())),
836
  ) -> StreamingResponse:
837
  if not tts_manager.model:
838
- raise HTTPException(status_code=503, detail="TTS model still loading, please try again later")
839
  transcription = await transcribe_audio(file, language)
840
  logger.info(f"Transcribed text: {transcription.text}")
841
 
 
14
  from slowapi import Limiter
15
  from slowapi.util import get_remote_address
16
  import torch
17
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, AutoModel, Gemma3ForConditionalGeneration
18
  from IndicTransToolkit import IndicProcessor
19
  import json
20
  import asyncio
 
84
  self.device = torch.device(device)
85
  self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
86
  self.model = None
 
87
  self.processor = None
88
+ self.is_loaded = False
89
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
90
 
91
  async def load(self):
 
99
  torch_dtype=self.torch_dtype
100
  )
101
  self.model.eval()
102
+ self.processor = await asyncio.to_thread(
103
+ AutoProcessor.from_pretrained,
104
+ self.model_name
105
+ )
106
  self.is_loaded = True
107
+ logger.info(f"LLM {self.model_name} loaded asynchronously on {self.device}")
108
  except Exception as e:
109
+ logger.error(f"Failed to load LLM: {str(e)}")
110
+ raise
111
 
112
  def unload(self):
113
  if self.is_loaded:
 
142
  return_dict=True,
143
  return_tensors="pt"
144
  ).to(self.device, dtype=torch.bfloat16)
 
 
145
  except Exception as e:
146
  logger.error(f"Error in tokenization: {str(e)}")
147
  raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
 
191
  return_dict=True,
192
  return_tensors="pt"
193
  ).to(self.device, dtype=torch.bfloat16)
 
194
  except Exception as e:
195
  logger.error(f"Error in apply_chat_template: {str(e)}")
196
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
 
240
  return_dict=True,
241
  return_tensors="pt"
242
  ).to(self.device, dtype=torch.bfloat16)
 
243
  except Exception as e:
244
  logger.error(f"Error in apply_chat_template: {str(e)}")
245
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
 
267
  self.repo_id = "ai4bharat/IndicF5"
268
 
269
  async def load(self):
270
+ if not self.model:
271
+ logger.info("Loading TTS model IndicF5 asynchronously...")
272
+ self.model = await asyncio.to_thread(
273
+ AutoModel.from_pretrained,
274
+ self.repo_id,
275
+ trust_remote_code=True
276
+ )
277
+ self.model = self.model.to(self.device_type)
278
+ logger.info("TTS model IndicF5 loaded asynchronously")
279
 
280
  def synthesize(self, text, ref_audio_path, ref_text):
281
  if not self.model:
 
368
  elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
369
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
370
  else:
371
+ raise ValueError("Invalid language combination")
372
 
373
+ self.tokenizer = await asyncio.to_thread(
374
+ AutoTokenizer.from_pretrained,
375
+ model_name,
376
+ trust_remote_code=True
377
+ )
378
  self.model = await asyncio.to_thread(
379
  AutoModelForSeq2SeqLM.from_pretrained,
380
  model_name,
 
384
  )
385
  self.model = self.model.to(self.device_type)
386
  self.model = torch.compile(self.model, mode="reduce-overhead")
387
+ logger.info(f"Translation model {model_name} loaded asynchronously")
388
 
389
  class ModelManager:
390
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
 
394
  self.is_lazy_loading = is_lazy_loading
395
 
396
  async def load_model(self, src_lang, tgt_lang, key):
397
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} asynchronously")
398
  translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
399
  await translate_manager.load()
400
  self.models[key] = translate_manager
401
+ logger.info(f"Loaded translation model for {key} asynchronously")
402
 
403
  def get_model(self, src_lang, tgt_lang):
404
  key = self._get_model_key(src_lang, tgt_lang)
 
426
  self.model_language = {"kannada": "kn"}
427
 
428
  async def load(self):
429
+ if not self.model:
430
+ logger.info("Loading ASR model asynchronously...")
431
+ self.model = await asyncio.to_thread(
432
+ AutoModel.from_pretrained,
433
+ "ai4bharat/indic-conformer-600m-multilingual",
434
+ trust_remote_code=True
435
+ )
436
+ self.model = self.model.to(self.device_type)
437
+ logger.info("ASR model loaded asynchronously")
438
 
439
  # Global Managers
440
  llm_manager = LLMManager(settings.llm_model_name)
 
485
  @asynccontextmanager
486
  async def lifespan(app: FastAPI):
487
  async def load_all_models():
488
+ try:
489
+ tasks = [
490
+ llm_manager.load(),
491
+ tts_manager.load(),
492
+ asr_manager.load(),
493
+ ]
494
+
495
+ translation_tasks = [
496
+ model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic'),
497
+ model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng'),
498
+ model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic'),
499
+ ]
500
+
501
+ for config in translation_configs:
502
+ src_lang = config["src_lang"]
503
+ tgt_lang = config["tgt_lang"]
504
+ key = model_manager._get_model_key(src_lang, tgt_lang)
505
+ translation_tasks.append(model_manager.load_model(src_lang, tgt_lang, key))
506
+
507
+ await asyncio.gather(*tasks, *translation_tasks)
508
+ logger.info("All models loaded successfully asynchronously")
509
+ except Exception as e:
510
+ logger.error(f"Error loading models: {str(e)}")
511
+ raise
512
 
513
+ logger.info("Starting asynchronous model loading...")
514
+ await load_all_models()
515
  yield
516
  llm_manager.unload()
517
  logger.info("Server shutdown complete")
 
540
  @app.post("/audio/speech", response_class=StreamingResponse)
541
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
542
  if not tts_manager.model:
543
+ raise HTTPException(status_code=503, detail="TTS model not loaded")
544
  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
545
  if not request.text.strip():
546
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
 
605
  await model_manager.load_model(src_lang, tgt_lang, key)
606
  translate_manager = model_manager.get_model(src_lang, tgt_lang)
607
 
608
+ if not translate_manager.model:
609
  await translate_manager.load()
610
 
611
  request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
 
828
  @app.post("/transcribe/", response_model=TranscriptionResponse)
829
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
830
  if not asr_manager.model:
831
+ raise HTTPException(status_code=503, detail="ASR model not loaded")
832
  try:
833
  wav, sr = torchaudio.load(file.file)
834
  wav = torch.mean(wav, dim=0, keepdim=True)
 
849
  language: str = Query(..., enum=list(asr_manager.model_language.keys())),
850
  ) -> StreamingResponse:
851
  if not tts_manager.model:
852
+ raise HTTPException(status_code=503, detail="TTS model not loaded")
853
  transcription = await transcribe_audio(file, language)
854
  logger.info(f"Transcribed text: {transcription.text}")
855