|
import torch |
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
import os |
|
from tqdm import tqdm |
|
import pandas as pd |
|
import time |
|
import sys |
|
from datasets import load_dataset |
|
from src.utils import read_data |
|
|
|
class NLLBTranslator:
    """Wrapper around Meta's NLLB-200 seq2seq model for text translation.

    Accepts either common language names ("english", "fr") or raw NLLB
    codes ("eng_Latn") for both source and target languages.
    """

    def __init__(self, model_name="facebook/nllb-200-3.3B"):
        """
        Initialize the NLLB model and tokenizer for translation.

        Args:
            model_name (str): Hugging Face model identifier to load.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)

    def _get_nllb_code(self, language: str) -> str:
        """
        Map a common language name to its NLLB language code.

        Args:
            language (str): Common language name (case-insensitive,
                surrounding whitespace ignored).

        Returns:
            str: NLLB language code, or None if the language is unknown.

        Examples:
            >>> translator._get_nllb_code("english")
            'eng_Latn'
            >>> translator._get_nllb_code("Chinese")
            'zho_Hans'
        """
        language_mapping = {
            "english": "eng_Latn",
            "eng": "eng_Latn",
            "en": "eng_Latn",

            "hindi": "hin_Deva",
            "hi": "hin_Deva",

            "french": "fra_Latn",
            "fr": "fra_Latn",

            "korean": "kor_Hang",
            "ko": "kor_Hang",

            "spanish": "spa_Latn",
            "es": "spa_Latn",

            "chinese": "zho_Hans",
            "chinese simplified": "zho_Hans",
            "chinese traditional": "zho_Hant",
            "mandarin": "zho_Hans",
            "zh-cn": "zho_Hans",

            "japanese": "jpn_Jpan",
            "jpn": "jpn_Jpan",
            "ja": "jpn_Jpan",

            "german": "deu_Latn",
            "de": "deu_Latn",
        }

        normalized_input = language.lower().strip()
        return language_mapping.get(normalized_input)

    def _resolve_lang(self, language: str) -> str:
        """
        Resolve *language* to an NLLB code.

        Common names are mapped via ``_get_nllb_code``; anything not in the
        mapping (e.g. an already-valid code such as "eng_Latn") is passed
        through unchanged. This fixes the previous behavior where the
        default arguments of ``translate`` ("eng_Latn"/"fra_Latn") resolved
        to None.
        """
        return self._get_nllb_code(language) or language

    def add_language_code(self, name_code_dict, language, code):
        """
        Add a language code to the dictionary if it is not already present.

        Args:
            name_code_dict (dict): Dictionary of language names to codes.
            language (str): Language name (normalized to lowercase, stripped).
            code (str): Language code.

        Returns:
            dict: The same dictionary, possibly updated in place.
        """
        normalized_language = language.lower().strip()

        # Existing entries win: never overwrite a previously registered code.
        if normalized_language not in name_code_dict:
            name_code_dict[normalized_language] = code

        return name_code_dict

    def translate(self, text, source_lang="eng_Latn", target_lang="fra_Latn", batch_size=None):
        """
        Translate text from source language to target language.

        Args:
            text (str | list[str]): Text(s) to translate.
            source_lang (str): Source language name or NLLB code.
            target_lang (str): Target language name or NLLB code.
            batch_size (int | None): If given, list inputs are translated in
                chunks of this size to bound memory use; None translates the
                whole list in one batch.

        Returns:
            str | list[str]: Translation(s), matching the type of *text*.
        """
        source_code = self._resolve_lang(source_lang)
        target_code = self._resolve_lang(target_lang)

        # NLLB encodes the source language as a special prefix token, so
        # src_lang MUST be set on the tokenizer BEFORE encoding the input;
        # otherwise the requested source language is silently ignored.
        self.tokenizer.src_lang = source_code
        forced_bos_token_id = self.tokenizer.convert_tokens_to_ids(target_code)

        single_input = isinstance(text, str)
        texts = [text] if single_input else list(text)
        # batch_size previously existed in the signature but was never used.
        chunk = batch_size if batch_size else max(len(texts), 1)

        translations = []
        for start in range(0, len(texts), chunk):
            batch = texts[start:start + chunk]
            inputs = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device)

            # Inference only: no autograd bookkeeping needed.
            with torch.no_grad():
                translated_tokens = self.model.generate(
                    **inputs,
                    max_length=256,
                    # Deterministic beam search. The previous combination of
                    # num_beams=5 with do_sample=True/temperature=0.5 made
                    # translations non-reproducible run to run.
                    num_beams=5,
                    forced_bos_token_id=forced_bos_token_id,
                )

            translations.extend(
                self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
            )

        # Return a bare string only for string input (the old shape-based
        # check returned a str even for a single-element list input).
        return translations[0] if single_input else translations
|
|
|
def main():
    """Demo: translate a few English sentences to French with NLLB."""
    print("Loading model and tokenizer...")
    translator = NLLBTranslator()

    texts = [
        "Hello, how are you?",
        "This is a test of the NLLB translation model.",
        "Machine learning is fascinating."
    ]

    print("\nTranslating texts from English to French:")
    # "fr" is resolved to "fra_Latn" by the translator; batch_size caps
    # how many sentences are encoded per generate() call.
    translations = translator.translate(texts, source_lang="en", target_lang="fr", batch_size=2)

    # The old code assigned the result to two unused names and never
    # displayed it; print each source/translation pair instead.
    for source, translation in zip(texts, translations):
        print(f"{source} -> {translation}")
|
|
|
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":

    main()
|
|