sachin committed
Commit · 664a539 · 1 Parent(s): 998dfd0
test-

Files changed: src/server/main.py (+7 -8)
src/server/main.py  CHANGED

@@ -269,10 +269,9 @@ class TTSManager:
     async def load(self):
         if not self.model:
             logger.info("Loading TTS model IndicF5 asynchronously...")
-            local_path = "/app/models/tts_model"
             self.model = await asyncio.to_thread(
                 AutoModel.from_pretrained,
-
+                self.repo_id,
                 trust_remote_code=True
             )
             self.model = self.model.to(self.device_type)
@@ -363,29 +362,29 @@ class TranslateManager:
     async def load(self):
         if not self.tokenizer or not self.model:
             if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
-
+                model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
             elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
-
+                model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
             elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
-
+                model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
             else:
                 raise ValueError("Invalid language combination")

             self.tokenizer = await asyncio.to_thread(
                 AutoTokenizer.from_pretrained,
-
+                model_name,
                 trust_remote_code=True
             )
             self.model = await asyncio.to_thread(
                 AutoModelForSeq2SeqLM.from_pretrained,
-
+                model_name,
                 trust_remote_code=True,
                 torch_dtype=torch.float16,
                 attn_implementation="flash_attention_2"
             )
             self.model = self.model.to(self.device_type)
             self.model = torch.compile(self.model, mode="reduce-overhead")
-            logger.info(f"Translation model {
+            logger.info(f"Translation model {model_name} loaded asynchronously")

 class ModelManager:
     def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):