# ALIGN-Sim / Models / MultilingualTranslationModel.py
# Author: yzm0034
# Commit 4f08d2c (verified): "Upload folder using huggingface_hub"
import os
import re
import sys
import time
from typing import Optional

import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from src.utils import read_data
class NLLBTranslator:
    """Wrapper around a Hugging Face NLLB-200 seq2seq model for text translation."""

    # Common language names / short codes -> NLLB language codes.
    # Keys are lowercase; lookups normalise input with ``lower().strip()``.
    _LANGUAGE_MAPPING = {
        # English variations
        "english": "eng_Latn",
        "eng": "eng_Latn",
        "en": "eng_Latn",
        # Hindi variations
        "hindi": "hin_Deva",
        "hi": "hin_Deva",
        # French variations
        "french": "fra_Latn",
        "fr": "fra_Latn",
        # Korean variations
        "korean": "kor_Hang",
        "ko": "kor_Hang",
        # Spanish variations
        "spanish": "spa_Latn",
        "es": "spa_Latn",
        # Chinese variations (defaulting to Simplified)
        "chinese": "zho_Hans",
        "chinese simplified": "zho_Hans",
        "chinese traditional": "zho_Hant",
        "mandarin": "zho_Hans",
        "zh-cn": "zho_Hans",
        # Japanese variations
        "japanese": "jpn_Jpan",
        "jpn": "jpn_Jpan",
        "ja": "jpn_Jpan",
        # German variations
        "german": "deu_Latn",
        "de": "deu_Latn",
    }

    # Shape of an NLLB code such as "eng_Latn" / "zho_Hans": three lowercase
    # letters, an underscore, then a capitalised four-letter script tag.
    # Such strings are passed through unchanged by _get_nllb_code.
    _NLLB_CODE_PATTERN = re.compile(r"[a-z]{3}_[A-Z][a-z]{3}")

    def __init__(self, model_name="facebook/nllb-200-3.3B"):
        """
        Initialize the NLLB model and tokenizer for translation.

        Args:
            model_name (str): Hugging Face model id or local path to load.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)

    def _get_nllb_code(self, language: str) -> Optional[str]:
        """
        Map a common language name, or an NLLB code, to an NLLB language code.

        Args:
            language (str): Common language name or short code (case-insensitive),
                e.g. "english", "fr", "Chinese"; or an NLLB code such as
                "eng_Latn", which is returned unchanged (case-sensitive).
        Returns:
            str or None: NLLB language code, or None if the language is unknown.
        Examples:
            >>> translator._get_nllb_code("english")
            'eng_Latn'
            >>> translator._get_nllb_code("Chinese")
            'zho_Hans'
            >>> translator._get_nllb_code("fra_Latn")
            'fra_Latn'
        """
        if not isinstance(language, str):
            return None
        stripped = language.strip()
        # NLLB codes are case-sensitive ("eng_Latn"), so test them before lowercasing.
        if self._NLLB_CODE_PATTERN.fullmatch(stripped):
            return stripped
        return self._LANGUAGE_MAPPING.get(stripped.lower())

    def add_language_code(self, name_code_dict, language, code):
        """
        Add a language code to a name->code dictionary if not already present.

        The dictionary is mutated in place and also returned.

        Args:
            name_code_dict (dict): Dictionary of language names to codes.
            language (str): Language name; stored lowercased and stripped.
            code (str): Language code.
        Returns:
            dict: The same (updated) dictionary.
        """
        # TODO: Add this functionality to _get_nllb_code
        normalized_language = language.lower().strip()
        # Existing entries are never overwritten.
        if normalized_language not in name_code_dict:
            name_code_dict[normalized_language] = code
        return name_code_dict

    def _resolve_language(self, language, arg_name):
        """Return the NLLB code for ``language`` or raise ValueError if it is unusable."""
        code = self._get_nllb_code(language)
        if code is None:
            raise ValueError(
                f"Unsupported {arg_name} {language!r}: no NLLB language code found."
            )
        # A code with the right shape may still be absent from the tokenizer's
        # vocabulary; convert_tokens_to_ids then yields the unknown-token id.
        token_id = self.tokenizer.convert_tokens_to_ids(code)
        if token_id is None or token_id == self.tokenizer.unk_token_id:
            raise ValueError(
                f"Unsupported {arg_name} {language!r}: {code!r} is not a language "
                f"token of this tokenizer."
            )
        return code

    def translate(self, text, source_lang="eng_Latn", target_lang="fra_Latn", batch_size=None):
        """
        Translate text from source language to target language.

        Args:
            text (str or list[str]): A single sentence or a list of sentences.
            source_lang (str): Source language, as a common name ("english", "en")
                or an NLLB code ("eng_Latn").
            target_lang (str): Target language, in the same forms as ``source_lang``.
            batch_size (int, optional): Maximum number of sentences per
                ``generate`` call. ``None`` translates all sentences in one batch.
        Returns:
            str or list[str]: The translation of ``text`` if ``text`` is a str,
            otherwise a list with one translation per input sentence, in order.
        Raises:
            ValueError: If a language cannot be mapped to an NLLB code known to
                the tokenizer, or ``batch_size`` is smaller than 1.
        """
        src_code = self._resolve_language(source_lang, "source_lang")
        tgt_code = self._resolve_language(target_lang, "target_lang")
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        single = isinstance(text, str)
        sentences = [text] if single else list(text)
        if not sentences:
            return []
        if batch_size is None:
            batch_size = len(sentences)

        # The NLLB tokenizer prepends the *source* language token according to
        # src_lang, so it must be set before tokenizing.
        self.tokenizer.src_lang = src_code
        # Decoding is forced to start with the *target* language token.
        forced_bos_token_id = self.tokenizer.convert_tokens_to_ids(tgt_code)

        translations = []
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            inputs = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device)
            # do_sample=True makes the output non-deterministic across calls.
            translated_tokens = self.model.generate(
                **inputs,
                max_length=256,
                num_beams=5,
                temperature=0.5,
                do_sample=True,
                forced_bos_token_id=forced_bos_token_id,
            )
            translations.extend(
                self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
            )
        return translations[0] if single else translations
def main():
    """Demo: load the default NLLB model and translate a few English sentences to French."""
    print("Loading model and tokenizer...")
    translator = NLLBTranslator()

    # Example translations
    texts = [
        "Hello, how are you?",
        "This is a test of the NLLB translation model.",
        "Machine learning is fascinating.",
    ]

    print("\nTranslating texts from English to French:")
    translations = translator.translate(
        texts, source_lang="english", target_lang="fr", batch_size=2
    )
    # A list input yields one translation per sentence, in order.
    for source, translation in zip(texts, translations):
        print(f"{source} -> {translation}")


if __name__ == "__main__":
    main()