sachin commited on
Commit
664a539
·
1 Parent(s): 998dfd0
Files changed (1) hide show
  1. src/server/main.py +7 -8
src/server/main.py CHANGED
@@ -269,10 +269,9 @@ class TTSManager:
269
  async def load(self):
270
  if not self.model:
271
  logger.info("Loading TTS model IndicF5 asynchronously...")
272
- local_path = "/app/models/tts_model"
273
  self.model = await asyncio.to_thread(
274
  AutoModel.from_pretrained,
275
- local_path,
276
  trust_remote_code=True
277
  )
278
  self.model = self.model.to(self.device_type)
@@ -363,29 +362,29 @@ class TranslateManager:
363
  async def load(self):
364
  if not self.tokenizer or not self.model:
365
  if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
366
- local_path = "/app/models/trans_en_indic"
367
  elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
368
- local_path = "/app/models/trans_indic_en"
369
  elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
370
- local_path = "/app/models/trans_indic_indic"
371
  else:
372
  raise ValueError("Invalid language combination")
373
 
374
  self.tokenizer = await asyncio.to_thread(
375
  AutoTokenizer.from_pretrained,
376
- local_path,
377
  trust_remote_code=True
378
  )
379
  self.model = await asyncio.to_thread(
380
  AutoModelForSeq2SeqLM.from_pretrained,
381
- local_path,
382
  trust_remote_code=True,
383
  torch_dtype=torch.float16,
384
  attn_implementation="flash_attention_2"
385
  )
386
  self.model = self.model.to(self.device_type)
387
  self.model = torch.compile(self.model, mode="reduce-overhead")
388
- logger.info(f"Translation model {local_path} loaded asynchronously")
389
 
390
  class ModelManager:
391
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
 
269
  async def load(self):
270
  if not self.model:
271
  logger.info("Loading TTS model IndicF5 asynchronously...")
 
272
  self.model = await asyncio.to_thread(
273
  AutoModel.from_pretrained,
274
+ self.repo_id,
275
  trust_remote_code=True
276
  )
277
  self.model = self.model.to(self.device_type)
 
362
  async def load(self):
363
  if not self.tokenizer or not self.model:
364
  if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
365
+ model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
366
  elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
367
+ model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
368
  elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
369
+ model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
370
  else:
371
  raise ValueError("Invalid language combination")
372
 
373
  self.tokenizer = await asyncio.to_thread(
374
  AutoTokenizer.from_pretrained,
375
+ model_name,
376
  trust_remote_code=True
377
  )
378
  self.model = await asyncio.to_thread(
379
  AutoModelForSeq2SeqLM.from_pretrained,
380
+ model_name,
381
  trust_remote_code=True,
382
  torch_dtype=torch.float16,
383
  attn_implementation="flash_attention_2"
384
  )
385
  self.model = self.model.to(self.device_type)
386
  self.model = torch.compile(self.model, mode="reduce-overhead")
387
+ logger.info(f"Translation model {model_name} loaded asynchronously")
388
 
389
  class ModelManager:
390
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):