# Author: Alain Vaucher
# Last commit 80ffb8e: "Explicitly download the CDE data; add logs"
# (Source-viewer residue: raw | history blame | 2.28 kB)
import logging
from typing import Iterable, Iterator, List, Union
import chemdataextractor
import sentencepiece as spm
from chemdataextractor.data import Package
from rxn.onmt_utils.internal_translation_utils import TranslationResult
from rxn.onmt_utils.translator import Translator
# Module-level logger. A NullHandler is attached so that importing this module
# does not trigger logging's "last resort" output when the application has not
# configured logging itself.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def download_cde_data() -> None:
package = Package("models/punkt_chem-1.0.pickle")
if package.local_exists():
return
logger.info("Downloading the necessary ChemDataExtractor data...")
package.download()
logger.info("Downloading the necessary ChemDataExtractor data... Done.")
def split_into_sentences(text: str) -> List[str]:
    """Split ``text`` into sentences using ChemDataExtractor's ``Paragraph``.

    Args:
        text: the raw paragraph text to segment.

    Returns:
        The text of each sentence found by ChemDataExtractor, in order.
    """
    # NOTE(review): presumably relies on the CDE data fetched by
    # download_cde_data() (the punkt_chem model); confirm it is called beforehand.
    sentences = chemdataextractor.doc.Paragraph(text).sentences
    return [sent.text for sent in sentences]
class SentencePieceTokenizer:
    """Convert between plain text and space-separated SentencePiece pieces.

    Args:
        model_file: path to the SentencePiece model, loaded at construction time.
    """

    def __init__(self, model_file: str):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_file)

    def tokenize(self, sentence: str) -> str:
        """Encode ``sentence`` into pieces and return them joined by single spaces."""
        return " ".join(self.sp.EncodeAsPieces(sentence))

    def detokenize(self, sentence: str) -> str:
        """Split ``sentence`` on single spaces and decode the resulting pieces.

        Intended to reverse :meth:`tokenize`.
        """
        # Splits on the literal " " (not arbitrary whitespace), so runs of spaces
        # produce empty pieces that are passed on to the decoder as-is.
        return self.sp.DecodePieces(sentence.split(" "))
class TranslatorWithSentencePiece:
    """Translate plain-text sentences with a ``Translator`` that expects SentencePiece input.

    Sentences are tokenized with a SentencePiece model before translation, and every
    returned translation is detokenized back into plain text.

    Args:
        translation_model: model path(s) passed to ``Translator.from_model_path``
            (presumably several paths form an ensemble; confirm in rxn-onmt-utils).
        sentencepiece_model: path to the SentencePiece model file.
    """

    def __init__(
        self, translation_model: Union[str, Iterable[str]], sentencepiece_model: str
    ):
        # The tokenizer is loaded before the translation model.
        self.sp = SentencePieceTokenizer(sentencepiece_model)
        self.translator = Translator.from_model_path(translation_model)

    def translate(self, sentences: List[str]) -> List[str]:
        """Return the text of the first (top) translation for each input sentence."""
        groups = self.translate_multiple_with_scores(sentences)
        return [group[0].text for group in groups]

    def translate_multiple_with_scores(
        self, sentences: List[str], n_best=1
    ) -> Iterator[List[TranslationResult]]:
        """Lazily yield the ``n_best`` detokenized translation results per sentence.

        This is a generator: nothing happens until it is first advanced. At that
        point all sentences are tokenized up front and handed to the translator.

        Args:
            sentences: plain-text sentences to translate.
            n_best: number of results requested per sentence from the translator.

        Yields:
            For each input sentence, the translator's list of ``TranslationResult``
            objects, with their ``text`` replaced by the detokenized string.
        """
        sp_inputs = [self.sp.tokenize(sentence) for sentence in sentences]
        raw_groups = self.translator.translate_multiple_with_scores(sp_inputs, n_best)
        for candidates in raw_groups:
            for candidate in candidates:
                # The result objects are mutated in place rather than copied.
                candidate.text = self.sp.detokenize(candidate.text)
            yield candidates