|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
from IndicTransToolkit import IndicProcessor |
|
from fastapi import HTTPException |
|
from logging_config import logger |
|
from typing import List |
|
import torch |
|
# Prefer GPU when available; all models are moved to this device by default.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class TranslateManager:
    """Holds one tokenizer/model pair for a single IndicTrans2 translation
    direction (en->indic, indic->en, or indic->indic)."""

    def __init__(self, src_lang, tgt_lang, device_type=DEVICE, use_distilled=True):
        """Load the checkpoint matching (src_lang, tgt_lang).

        Args:
            src_lang: source language code (flores-style, e.g. "eng_Latn");
                only the "eng" prefix is inspected here.
            tgt_lang: target language code, same format.
            device_type: torch device string ("cuda" or "cpu").
            use_distilled: pick the smaller distilled checkpoint when True.
        """
        self.device_type = device_type
        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)

    @staticmethod
    def _select_model_name(src_lang, tgt_lang, use_distilled):
        """Map a (src, tgt) language pair to the matching IndicTrans2
        checkpoint name on the Hugging Face hub.

        Raises:
            ValueError: if both sides are English (unsupported direction).
        """
        src_is_eng = src_lang.startswith("eng")
        tgt_is_eng = tgt_lang.startswith("eng")
        if src_is_eng and not tgt_is_eng:
            return "ai4bharat/indictrans2-en-indic-dist-200M" if use_distilled else "ai4bharat/indictrans2-en-indic-1B"
        if not src_is_eng and tgt_is_eng:
            return "ai4bharat/indictrans2-indic-en-dist-200M" if use_distilled else "ai4bharat/indictrans2-indic-en-1B"
        if not src_is_eng and not tgt_is_eng:
            return "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
        raise ValueError("Invalid language combination: English to English translation is not supported.")

    def initialize_model(self, src_lang, tgt_lang, use_distilled):
        """Download/load the tokenizer and model for the requested direction.

        Returns:
            (tokenizer, model): model is moved to ``self.device_type`` and
            switched to eval mode.
        """
        model_name = self._select_model_name(src_lang, tgt_lang, use_distilled)
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Fix: fp16 + flash_attention_2 are CUDA-only. The original passed
        # them unconditionally, which made loading fail on a CPU-only host
        # even though DEVICE correctly fell back to "cpu". Use fp32 eager
        # attention there instead.
        if self.device_type == "cuda":
            model_kwargs = {"torch_dtype": torch.float16, "attn_implementation": "flash_attention_2"}
        else:
            model_kwargs = {"torch_dtype": torch.float32}
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            **model_kwargs,
        ).to(self.device_type)
        model.eval()  # inference only: disables dropout
        return tokenizer, model
|
|
|
class ModelManager:
    """Lazily builds and caches one TranslateManager per translation
    direction so each checkpoint is loaded at most once per process."""

    def __init__(self, device_type=DEVICE, use_distilled=True, is_lazy_loading=True):
        """Remember construction options; no models are loaded here.

        Args:
            device_type: torch device string forwarded to TranslateManager.
            use_distilled: prefer the smaller distilled checkpoints.
            is_lazy_loading: retained flag; models are created on first use.
        """
        # Cache keyed by direction: 'eng_indic' / 'indic_eng' / 'indic_indic'.
        self.models: dict[str, TranslateManager] = {}
        self.device_type = device_type
        self.use_distilled = use_distilled
        self.is_lazy_loading = is_lazy_loading

    def get_model(self, src_lang, tgt_lang) -> TranslateManager:
        """Return the cached TranslateManager for this direction, creating
        it on first request.

        Raises:
            ValueError: for an English-to-English request.
        """
        # Classify the direction by whether each side is English, then map
        # the (src_is_eng, tgt_is_eng) pair to a cache key.
        direction = (src_lang.startswith("eng"), tgt_lang.startswith("eng"))
        key_by_direction = {
            (True, False): 'eng_indic',
            (False, True): 'indic_eng',
            (False, False): 'indic_indic',
        }
        if direction not in key_by_direction:
            raise ValueError("Invalid language combination: English to English translation is not supported.")
        key = key_by_direction[direction]
        if key not in self.models:
            self.models[key] = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
        return self.models[key]
|
|
|
# Shared IndicTrans2 pre/post-processor (adds language tags before encoding,
# restores surface forms after decoding). inference=True selects the
# inference-time processing path.
ip = IndicProcessor(inference=True)
# Module-level singleton; checkpoints are loaded lazily on first get_model().
model_manager = ModelManager()
|
|
|
async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    """Translate a batch of sentences from src_lang to tgt_lang.

    Args:
        sentences: non-empty list of input sentences.
        src_lang: source language code (flores-style, e.g. "eng_Latn").
        tgt_lang: target language code, same format.

    Returns:
        List of translated sentences, aligned with the input order.

    Raises:
        HTTPException: 400 when ``sentences`` is empty.
        ValueError: for an English-to-English language pair (from get_model).
    """
    # Fix: validate BEFORE touching the model cache. The original called
    # get_model() first, so an empty request could trigger a multi-GB lazy
    # checkpoint load only to be rejected with a 400 afterwards.
    if not sentences:
        raise HTTPException(status_code=400, detail="Input sentences are required")

    translate_manager = model_manager.get_model(src_lang, tgt_lang)

    # Add direction tags / normalize inputs, then tokenize as one padded batch
    # on the model's device.
    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = translate_manager.tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(translate_manager.device_type)

    # Beam-search decode without building autograd graphs.
    with torch.no_grad():
        generated_tokens = translate_manager.model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode with the target-side tokenizer context (IndicTrans2 uses
    # distinct source/target vocabularies).
    with translate_manager.tokenizer.as_target_tokenizer():
        decoded = translate_manager.tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    # Undo preprocessing (tag removal, entity restoration) for tgt_lang.
    translations = ip.postprocess_batch(decoded, lang=tgt_lang)
    return translations