import logging
from typing import Iterable, Iterator, List, Union

import chemdataextractor
import sentencepiece as spm
from chemdataextractor.data import Package

from rxn.onmt_utils.internal_translation_utils import TranslationResult
from rxn.onmt_utils.translator import Translator

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def download_cde_data() -> None:
    """Download the ChemDataExtractor sentence-splitting model if not already present."""
    package = Package("models/punkt_chem-1.0.pickle")
    if package.local_exists():
        return

    logger.info("Downloading the necessary ChemDataExtractor data...")
    package.download()
    logger.info("Downloading the necessary ChemDataExtractor data... Done.")


def split_into_sentences(text: str) -> List[str]:
    """Split a paragraph of text into sentences with ChemDataExtractor."""
    paragraph = chemdataextractor.doc.Paragraph(text)
    return [sentence.text for sentence in paragraph.sentences]


class SentencePieceTokenizer:
    """Thin wrapper around a SentencePiece model for tokenization and detokenization."""

    def __init__(self, model_file: str):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_file)

    def tokenize(self, sentence: str) -> str:
        """Tokenize a sentence into a string of space-separated SentencePiece tokens."""
        tokens = self.sp.EncodeAsPieces(sentence)
        return " ".join(tokens)

    def detokenize(self, sentence: str) -> str:
        """Reassemble a string of space-separated SentencePiece tokens into plain text."""
        tokens = sentence.split(" ")
        return self.sp.DecodePieces(tokens)


class TranslatorWithSentencePiece:
    """Translator that applies SentencePiece tokenization before translation
    and detokenizes the results afterwards."""

    def __init__(
        self, translation_model: Union[str, Iterable[str]], sentencepiece_model: str
    ):
        self.sp = SentencePieceTokenizer(sentencepiece_model)
        self.translator = Translator.from_model_path(translation_model)

    def translate(self, sentences: List[str]) -> List[str]:
        """Translate sentences, returning only the top hypothesis for each."""
        translations = self.translate_multiple_with_scores(sentences)
        return [t[0].text for t in translations]

    def translate_multiple_with_scores(
        self, sentences: List[str], n_best: int = 1
    ) -> Iterator[List[TranslationResult]]:
        """Translate sentences, yielding the n_best scored hypotheses for each."""
        tokenized_sentences = [self.sp.tokenize(s) for s in sentences]
        translations = self.translator.translate_multiple_with_scores(
            tokenized_sentences, n_best
        )
        for translation_group in translations:
            # Detokenize each hypothesis in place before handing it back.
            for t in translation_group:
                t.text = self.sp.detokenize(t.text)
            yield translation_group
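

# Minimal usage sketch tying the pieces together: download the sentence-splitting
# data, split a paragraph, and translate the sentences. The model paths
# ("model.pt", "sp_model.model") and the input text are hypothetical
# placeholders, not files shipped with this module.
if __name__ == "__main__":
    download_cde_data()

    sentences = split_into_sentences(
        "The mixture was stirred at 25 °C for 2 h. It was then filtered."
    )

    translator = TranslatorWithSentencePiece(
        translation_model="model.pt",  # hypothetical translation model checkpoint
        sentencepiece_model="sp_model.model",  # hypothetical SentencePiece model file
    )
    print(translator.translate(sentences))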