from typing import List

import torch
from fastapi import HTTPException
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from IndicTransToolkit import IndicProcessor

from logging_config import logger

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class TranslateManager:
    """Loads and holds the tokenizer/model pair for one translation direction."""

    def __init__(self, src_lang, tgt_lang, device_type=DEVICE, use_distilled=True):
        self.device_type = device_type
        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)

    def initialize_model(self, src_lang, tgt_lang, use_distilled):
        # Pick the IndicTrans2 checkpoint for the language direction; the
        # distilled variants trade some quality for a smaller footprint.
        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if use_distilled else "ai4bharat/indictrans2-en-indic-1B"
        elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
            model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if use_distilled else "ai4bharat/indictrans2-indic-en-1B"
        elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
        else:
            raise ValueError("Invalid language combination: English-to-English translation is not supported.")

        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # flash_attention_2 requires the flash-attn package and a CUDA device;
        # fall back to float32 and the default attention implementation on CPU.
        use_cuda = self.device_type == "cuda"
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            attn_implementation="flash_attention_2" if use_cuda else "eager",
        ).to(self.device_type)
        return tokenizer, model


class ModelManager:
    """Caches TranslateManager instances, one per translation direction."""

    def __init__(self, device_type=DEVICE, use_distilled=True, is_lazy_loading=True):
        self.models: dict[str, TranslateManager] = {}
        self.device_type = device_type
        self.use_distilled = use_distilled
        # Models are instantiated on first use in get_model; this flag is
        # kept so callers can opt into eager loading instead.
        self.is_lazy_loading = is_lazy_loading

    def get_model(self, src_lang, tgt_lang) -> TranslateManager:
        # All English-to-Indic pairs share one model, and likewise for the
        # other two directions, so cache by direction rather than by pair.
        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            key = "eng_indic"
        elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
            key = "indic_eng"
        elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            key = "indic_indic"
        else:
            raise ValueError("Invalid language combination: English-to-English translation is not supported.")
        if key not in self.models:
            logger.info(f"Loading translation model for direction '{key}'")
            self.models[key] = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
        return self.models[key]
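
    # Sketch (an assumption, not part of the original class): when
    # is_lazy_loading is False, a caller could warm the cache up front by
    # eagerly instantiating the model for each known language pair.
    def preload(self, language_pairs: List[tuple]) -> None:
        for src_lang, tgt_lang in language_pairs:
            self.get_model(src_lang, tgt_lang)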


# Module-level singletons shared by all requests.
ip = IndicProcessor(inference=True)
model_manager = ModelManager()


async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    # Validate the input before touching the model cache so an empty request
    # cannot trigger a model load.
    if not sentences:
        raise HTTPException(status_code=400, detail="Input sentences are required")
    translate_manager = model_manager.get_model(src_lang, tgt_lang)

    # Apply IndicTrans2 preprocessing (script normalization and language tags).
    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = translate_manager.tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(translate_manager.device_type)

    with torch.no_grad():
        generated_tokens = translate_manager.model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode with the target-side tokenizer, then strip the language tags and
    # restore the original script.
    with translate_manager.tokenizer.as_target_tokenizer():
        generated_tokens = translate_manager.tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
    return translations
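

# Minimal usage sketch (an assumption, not part of the original module): the
# coroutine must run inside an event loop, and IndicTrans2 expects FLORES-200
# language codes such as "eng_Latn" and "hin_Deva".
if __name__ == "__main__":
    import asyncio

    async def _demo():
        result = await perform_internal_translation(
            ["Hello, how are you?"], src_lang="eng_Latn", tgt_lang="hin_Deva"
        )
        print(result)

    asyncio.run(_demo())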