import hashlib
import os
import uuid
from typing import Dict, List, Tuple, Union

import regex as re
import sentencepiece as spm
from indicnlp.normalize import indic_normalize
from indicnlp.tokenize import indic_detokenize, indic_tokenize
from indicnlp.tokenize.sentence_tokenize import DELIM_PAT_NO_DANDA, sentence_split
from indicnlp.transliterate import unicode_transliterate
from mosestokenizer import MosesSentenceSplitter
from nltk.tokenize import sent_tokenize
from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer
from tqdm import tqdm

from .flores_codes_map_indic import flores_codes, iso_to_flores
from .normalize_punctuation import punc_norm
from .normalize_regex_inference import EMAIL_PATTERN, normalize
def split_sentences(paragraph: str, lang: str) -> List[str]:
    """
    Splits the input text paragraph into sentences. It uses `moses` for English and
    `indic-nlp` for Indic languages.

    Args:
        paragraph (str): input text paragraph.
        lang (str): flores language code.

    Returns:
        List[str]: list of sentences.
    """
    if lang == "eng_Latn":
        with MosesSentenceSplitter(flores_codes[lang]) as splitter:
            sents_moses = splitter([paragraph])
        sents_nltk = sent_tokenize(paragraph)
        # prefer whichever splitter produces fewer (more conservative) splits
        if len(sents_nltk) < len(sents_moses):
            sents = sents_nltk
        else:
            sents = sents_moses
        return [sent.replace("\xad", "") for sent in sents]
    else:
        return sentence_split(paragraph, lang=flores_codes[lang], delim_pat=DELIM_PAT_NO_DANDA)
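
# Illustrative usage sketch (exact splits depend on the underlying splitters):
#   >>> split_sentences("Hello world. This is a test.", lang="eng_Latn")
#   ['Hello world.', 'This is a test.']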
def add_token(sent: str, src_lang: str, tgt_lang: str, delimiter: str = " ") -> str:
    """
    Add special tokens indicating source and target language to the start of the input sentence.
    The resulting string will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".

    Args:
        sent (str): input sentence to be translated.
        src_lang (str): flores lang code of the input sentence.
        tgt_lang (str): flores lang code in which the input sentence will be translated.
        delimiter (str): separator to add between language tags and input sentence (default: " ").

    Returns:
        str: input sentence with the special tokens added to the start.
    """
    return src_lang + delimiter + tgt_lang + delimiter + sent
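
# Illustrative usage:
#   >>> add_token("नमस्ते दुनिया", "hin_Deva", "eng_Latn")
#   'hin_Deva eng_Latn नमस्ते दुनिया'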
def apply_lang_tags(sents: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    """
    Add special tokens indicating source and target language to the start of each input sentence.
    Each resulting input sentence will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".

    Args:
        sents (List[str]): input sentences to be translated.
        src_lang (str): flores lang code of the input sentences.
        tgt_lang (str): flores lang code in which the input sentences will be translated.

    Returns:
        List[str]: list of input sentences with the special tokens added to the start.
    """
    tagged_sents = []
    for sent in sents:
        tagged_sent = add_token(sent.strip(), src_lang, tgt_lang)
        tagged_sents.append(tagged_sent)
    return tagged_sents
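
# Illustrative usage:
#   >>> apply_lang_tags(["Hello", "World"], "eng_Latn", "hin_Deva")
#   ['eng_Latn hin_Deva Hello', 'eng_Latn hin_Deva World']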
def truncate_long_sentences(
    sents: List[str], placeholder_entity_map_sents: List[Dict]
) -> Tuple[List[str], List[Dict]]:
    """
    Truncates the sentences that exceed the maximum sequence length.
    The maximum sequence length for the IndicTrans2 model is 256 tokens.

    Args:
        sents (List[str]): list of input sentences to truncate.
        placeholder_entity_map_sents (List[Dict]): placeholder entity maps, one per input sentence.

    Returns:
        Tuple[List[str], List[Dict]]: tuple containing the list of sentences with truncation applied and the updated placeholder entity maps.
    """
    MAX_SEQ_LEN = 256
    new_sents = []
    placeholders = []
    for j, sent in enumerate(sents):
        words = sent.split()
        num_words = len(words)
        if num_words > MAX_SEQ_LEN:
            # split an over-long sentence into chunks of at most MAX_SEQ_LEN tokens
            chunks = []
            i = 0
            while i < len(words):
                chunks.append(" ".join(words[i : i + MAX_SEQ_LEN]))
                i += MAX_SEQ_LEN
            # replicate the placeholder map once per chunk
            placeholders.extend([placeholder_entity_map_sents[j]] * len(chunks))
            new_sents.extend(chunks)
        else:
            placeholders.append(placeholder_entity_map_sents[j])
            new_sents.append(sent)
    return new_sents, placeholders
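
# Behaviour sketch: a 600-token "sentence" becomes three chunks of at most 256
# tokens each, and its placeholder map is replicated once per chunk:
#   out, maps = truncate_long_sentences([" ".join(["tok"] * 600)], [{}])
#   assert len(out) == 3 and maps == [{}, {}, {}]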
class Model:
    """
    Model class to run the IndicTrans2 models using a Python interface.
    """

    def __init__(
        self,
        ckpt_dir: str,
        device: str = "cuda",
        input_lang_code_format: str = "flores",
        model_type: str = "ctranslate2",
    ):
        """
        Initialize the model class.

        Args:
            ckpt_dir (str): path of the model checkpoint directory.
            device (str, optional): where to load the model (default: cuda).
            input_lang_code_format (str, optional): format of the input language codes, "flores" or "iso" (default: flores).
            model_type (str, optional): inference backend, "ctranslate2" or "fairseq" (default: ctranslate2).
        """
        self.ckpt_dir = ckpt_dir
        self.en_tok = MosesTokenizer(lang="en")
        self.en_normalizer = MosesPunctNormalizer()
        self.en_detok = MosesDetokenizer(lang="en")
        self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()

        print("Initializing sentencepiece model for SRC and TGT")
        self.sp_src = spm.SentencePieceProcessor(
            model_file=os.path.join(ckpt_dir, "vocab", "model.SRC")
        )
        self.sp_tgt = spm.SentencePieceProcessor(
            model_file=os.path.join(ckpt_dir, "vocab", "model.TGT")
        )
        self.input_lang_code_format = input_lang_code_format

        print("Initializing model for translation")
        # initialize the translation backend
        if model_type == "ctranslate2":
            import ctranslate2

            self.translator = ctranslate2.Translator(
                self.ckpt_dir, device=device
            )  # , compute_type="auto")
            self.translate_lines = self.ctranslate2_translate_lines
        elif model_type == "fairseq":
            from .custom_interactive import Translator

            self.translator = Translator(
                data_dir=os.path.join(self.ckpt_dir, "final_bin"),
                checkpoint_path=os.path.join(self.ckpt_dir, "model", "checkpoint_best.pt"),
                batch_size=100,
            )
            self.translate_lines = self.fairseq_translate_lines
        else:
            raise NotImplementedError(f"Unknown model_type: {model_type}")
    def ctranslate2_translate_lines(self, lines: List[str]) -> List[str]:
        # each input line is already SPM-tokenized and language-tagged;
        # ctranslate2 expects a list of tokens per sentence
        tokenized_sents = [x.strip().split(" ") for x in lines]
        translations = self.translator.translate_batch(
            tokenized_sents,
            max_batch_size=9216,
            batch_type="tokens",
            max_input_length=160,
            max_decoding_length=256,
            beam_size=5,
        )
        # keep only the best hypothesis for each sentence
        translations = [" ".join(x.hypotheses[0]) for x in translations]
        return translations

    def fairseq_translate_lines(self, lines: List[str]) -> List[str]:
        return self.translator.translate(lines)
    def paragraphs_batch_translate__multilingual(self, batch_payloads: List[tuple]) -> List[str]:
        """
        Translates a batch of input paragraphs (including pre/post processing)
        from any language to any language.

        Args:
            batch_payloads (List[tuple]): batch of long input-texts to be translated, each in format: (paragraph, src_lang, tgt_lang)

        Returns:
            List[str]: batch of paragraph-translations in the respective languages.
        """
        paragraph_id_to_sentence_range = []
        global__sents = []
        global__preprocessed_sents = []
        global__preprocessed_sents_placeholder_entity_map = []
        for i in range(len(batch_payloads)):
            paragraph, src_lang, tgt_lang = batch_payloads[i]
            if self.input_lang_code_format == "iso":
                src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang]

            batch = split_sentences(paragraph, src_lang)
            global__sents.extend(batch)

            preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch(
                batch, src_lang, tgt_lang
            )

            global_sentence_start_index = len(global__preprocessed_sents)
            global__preprocessed_sents.extend(preprocessed_sents)
            global__preprocessed_sents_placeholder_entity_map.extend(placeholder_entity_map_sents)
            paragraph_id_to_sentence_range.append(
                (global_sentence_start_index, len(global__preprocessed_sents))
            )

        translations = self.translate_lines(global__preprocessed_sents)

        translated_paragraphs = []
        for paragraph_id, sentence_range in enumerate(paragraph_id_to_sentence_range):
            tgt_lang = batch_payloads[paragraph_id][2]
            if self.input_lang_code_format == "iso":
                tgt_lang = iso_to_flores[tgt_lang]

            postprocessed_sents = self.postprocess(
                translations[sentence_range[0] : sentence_range[1]],
                global__preprocessed_sents_placeholder_entity_map[
                    sentence_range[0] : sentence_range[1]
                ],
                tgt_lang,
            )
            translated_paragraph = " ".join(postprocessed_sents)
            translated_paragraphs.append(translated_paragraph)

        return translated_paragraphs
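
    # Usage sketch (assumes an initialized `model`; language codes are FLORES):
    #   payloads = [("Hello world. How are you?", "eng_Latn", "hin_Deva"),
    #               ("Good morning.", "eng_Latn", "tam_Taml")]
    #   translations = model.paragraphs_batch_translate__multilingual(payloads)
    #   # -> one translated paragraph per payload, in the same order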
    # translate a batch of sentences from src_lang to tgt_lang
    def batch_translate(self, batch: List[str], src_lang: str, tgt_lang: str) -> List[str]:
        """
        Translates a batch of input sentences (including pre/post processing)
        from source language to target language.

        Args:
            batch (List[str]): batch of input sentences to be translated.
            src_lang (str): flores source language code.
            tgt_lang (str): flores target language code.

        Returns:
            List[str]: batch of translated-sentences generated by the model.
        """
        assert isinstance(batch, list)

        if self.input_lang_code_format == "iso":
            src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang]

        preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch(
            batch, src_lang, tgt_lang
        )
        translations = self.translate_lines(preprocessed_sents)
        return self.postprocess(translations, placeholder_entity_map_sents, tgt_lang)
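
    # Usage sketch (assumes an initialized `model`):
    #   model.batch_translate(["Hello world.", "How are you?"], "eng_Latn", "hin_Deva")
    #   # -> one Hindi translation per input sentence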
    # translate a paragraph from src_lang to tgt_lang
    def translate_paragraph(self, paragraph: str, src_lang: str, tgt_lang: str) -> str:
        """
        Translates an input text paragraph (including pre/post processing)
        from source language to target language.

        Args:
            paragraph (str): input text paragraph to be translated.
            src_lang (str): flores source language code.
            tgt_lang (str): flores target language code.

        Returns:
            str: paragraph translation generated by the model.
        """
        assert isinstance(paragraph, str)

        if self.input_lang_code_format == "iso":
            flores_src_lang = iso_to_flores[src_lang]
        else:
            flores_src_lang = src_lang

        sents = split_sentences(paragraph, flores_src_lang)
        postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
        translated_paragraph = " ".join(postprocessed_sents)
        return translated_paragraph
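
    # Usage sketch (assumes an initialized `model`):
    #   model.translate_paragraph("Hello world. How are you?", "eng_Latn", "hin_Deva")
    #   # -> a single string: sentence translations joined by spaces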
    def preprocess_batch(
        self, batch: List[str], src_lang: str, tgt_lang: str
    ) -> Tuple[List[str], List[Dict]]:
        """
        Preprocess an array of sentences by normalizing, tokenizing, and possibly transliterating them.
        It then segments the normalized text with the SentencePiece model and adds language tags.

        Args:
            batch (List[str]): input list of sentences to preprocess.
            src_lang (str): flores language code of the input text sentences.
            tgt_lang (str): flores language code of the output text sentences.

        Returns:
            Tuple[List[str], List[Dict]]: a tuple of the list of preprocessed input text sentences and a corresponding list of dictionaries
            mapping placeholders to their original values.
        """
        preprocessed_sents, placeholder_entity_map_sents = self.preprocess(batch, lang=src_lang)
        tokenized_sents = self.apply_spm(preprocessed_sents)
        tokenized_sents, placeholder_entity_map_sents = truncate_long_sentences(
            tokenized_sents, placeholder_entity_map_sents
        )
        tagged_sents = apply_lang_tags(tokenized_sents, src_lang, tgt_lang)
        return tagged_sents, placeholder_entity_map_sents
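
    # Illustrative output shape (the actual pieces depend on the trained vocab;
    # the segmentation below is hypothetical):
    #   tagged, maps = model.preprocess_batch(["How are you?"], "eng_Latn", "hin_Deva")
    #   tagged  # -> ['eng_Latn hin_Deva ▁How ▁are ▁you ?']
    #   maps    # -> [{}] when no placeholder entities were extracted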
    def apply_spm(self, sents: List[str]) -> List[str]:
        """
        Applies SentencePiece encoding to the batch of input sentences.

        Args:
            sents (List[str]): batch of the input sentences.

        Returns:
            List[str]: batch of sentences encoded with the SentencePiece model.
        """
        return [" ".join(self.sp_src.encode(sent, out_type=str)) for sent in sents]
    def preprocess_sent(
        self,
        sent: str,
        normalizer: Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory],
        lang: str,
    ) -> Tuple[str, Dict]:
        """
        Preprocess an input text sentence by normalizing, tokenizing, and possibly transliterating it.

        Args:
            sent (str): input text sentence to preprocess.
            normalizer (Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory]): an object that performs normalization on the text.
            lang (str): flores language code of the input text sentence.

        Returns:
            Tuple[str, Dict]: a tuple containing the preprocessed input text sentence and a corresponding dictionary
            mapping placeholders to their original values.
        """
        iso_lang = flores_codes[lang]
        sent = punc_norm(sent, iso_lang)
        sent, placeholder_entity_map = normalize(sent)

        transliterate = True
        if lang.split("_")[1] in ["Arab", "Aran", "Olck", "Mtei", "Latn"]:
            transliterate = False

        if iso_lang == "en":
            processed_sent = " ".join(
                self.en_tok.tokenize(self.en_normalizer.normalize(sent.strip()), escape=False)
            )
        elif transliterate:
            # transliterate from the source script to Devanagari,
            # which is why we specify lang2_code as "hi".
            processed_sent = self.xliterator.transliterate(
                " ".join(
                    indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), iso_lang)
                ),
                iso_lang,
                "hi",
            ).replace(" ् ", "्")
        else:
            # scripts excluded above are not transliterated; only tokenize in the original script
            processed_sent = " ".join(
                indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), iso_lang)
            )

        return processed_sent, placeholder_entity_map
    def preprocess(self, sents: List[str], lang: str) -> Tuple[List[str], List[Dict]]:
        """
        Preprocess an array of sentences by normalizing, tokenizing, and possibly transliterating them.

        Args:
            sents (List[str]): input list of sentences to preprocess.
            lang (str): flores language code of the input text sentences.

        Returns:
            Tuple[List[str], List[Dict]]: a tuple of the list of preprocessed input text sentences and a corresponding list of dictionaries
            mapping placeholders to their original values.
        """
        processed_sents, placeholder_entity_map_sents = [], []

        if lang == "eng_Latn":
            normalizer = None
        else:
            normfactory = indic_normalize.IndicNormalizerFactory()
            normalizer = normfactory.get_normalizer(flores_codes[lang])

        for sent in sents:
            sent, placeholder_entity_map = self.preprocess_sent(sent, normalizer, lang)
            processed_sents.append(sent)
            placeholder_entity_map_sents.append(placeholder_entity_map)

        return processed_sents, placeholder_entity_map_sents
    def postprocess(
        self,
        sents: List[str],
        placeholder_entity_map: List[Dict],
        lang: str,
        common_lang: str = "hin_Deva",
    ) -> List[str]:
        """
        Postprocesses a batch of translated sentences generated by the model.

        Args:
            sents (List[str]): batch of translated sentences to postprocess.
            placeholder_entity_map (List[Dict]): dictionaries mapping placeholders to the original entity values, one per sentence.
            lang (str): flores language code of the target sentences.
            common_lang (str, optional): flores language code of the transliterated language (default: hin_Deva).

        Returns:
            List[str]: postprocessed batch of translated sentences.
        """
        lang_code, script_code = lang.split("_")

        # SPM decode
        for i in range(len(sents)):
            # sent_tokens = sents[i].split(" ")
            # sents[i] = self.sp_tgt.decode(sent_tokens)
            sents[i] = sents[i].replace(" ", "").replace("▁", " ").strip()

            # Fixes for Perso-Arabic scripts
            # TODO: Move these normalizations inside indic-nlp-library
            if script_code in {"Arab", "Aran"}:
                # UrduHack adds space before punctuations. Since the model was trained without fixing this issue, let's fix it now
                sents[i] = sents[i].replace(" ؟", "؟").replace(" ۔", "۔").replace(" ،", "،")
                # Kashmiri bugfix for palatalization: https://github.com/AI4Bharat/IndicTrans2/issues/11
                sents[i] = sents[i].replace("ٮ۪", "ؠ")

        assert len(sents) == len(placeholder_entity_map)

        # restore the original entity values in place of the placeholders
        for i in range(len(sents)):
            for key in placeholder_entity_map[i].keys():
                sents[i] = sents[i].replace(key, placeholder_entity_map[i][key])

        # Detokenize and transliterate to native scripts if applicable
        postprocessed_sents = []

        if lang == "eng_Latn":
            for sent in sents:
                postprocessed_sents.append(self.en_detok.detokenize(sent.split(" ")))
        else:
            for sent in sents:
                outstr = indic_detokenize.trivial_detokenize(
                    self.xliterator.transliterate(
                        sent, flores_codes[common_lang], flores_codes[lang]
                    ),
                    flores_codes[lang],
                )
                # Oriya bug: indic-nlp-library produces ଯ଼ instead of ୟ when converting from Devanagari to Odia
                # TODO: Find out what's the issue with unicode transliterator for Oriya and fix it
                if lang_code == "ory":
                    outstr = outstr.replace("ଯ଼", "ୟ")
                postprocessed_sents.append(outstr)

        return postprocessed_sents
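
# Minimal end-to-end usage sketch (assumes a local IndicTrans2 checkpoint laid
# out as expected above: vocab/model.SRC, vocab/model.TGT, and either a
# ctranslate2 model in ckpt_dir or a fairseq model under model/ and final_bin/):
#   model = Model("/path/to/ckpt_dir", device="cuda", model_type="ctranslate2")
#   model.translate_paragraph("Hello world. How are you?", "eng_Latn", "hin_Deva")
#   model.batch_translate(["Hello world."], "eng_Latn", "hin_Deva")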