"""Utilities for sentence-level translation of chemistry text: sentence
splitting with ChemDataExtractor, and SentencePiece (de)tokenization around
an OpenNMT translation model."""

import logging
from typing import Iterable, Iterator, List, Union

import chemdataextractor
import sentencepiece as spm
from chemdataextractor.data import Package
from rxn.onmt_utils.internal_translation_utils import TranslationResult
from rxn.onmt_utils.translator import Translator

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def download_cde_data() -> None:
    """Download the ChemDataExtractor data needed for sentence splitting,
    unless it is already available locally."""
    package = Package("models/punkt_chem-1.0.pickle")
    if package.local_exists():
        return

    logger.info("Downloading the necessary ChemDataExtractor data...")
    package.download()
    logger.info("Downloading the necessary ChemDataExtractor data... Done.")


def split_into_sentences(text: str) -> List[str]:
    """Split a text into sentences with ChemDataExtractor's chemistry-aware
    sentence tokenizer."""
    paragraph = chemdataextractor.doc.Paragraph(text)
    return [sentence.text for sentence in paragraph.sentences]
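

# Illustrative example; the exact splitting depends on the ChemDataExtractor
# model:
#
#     split_into_sentences("The sample was heated to 100 °C. It was cooled.")
#     # -> ["The sample was heated to 100 °C.", "It was cooled."]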


class SentencePieceTokenizer:
    """Wrapper around a trained SentencePiece model, converting sentences to
    and from strings of space-separated subword tokens."""

    def __init__(self, model_file: str):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_file)

    def tokenize(self, sentence: str) -> str:
        """Convert a sentence into a string of space-separated subword tokens."""
        tokens = self.sp.EncodeAsPieces(sentence)
        tokenized = " ".join(tokens)
        return tokenized

    def detokenize(self, sentence: str) -> str:
        """Reassemble a string of space-separated subword tokens into text."""
        tokens = sentence.split(" ")
        detokenized = self.sp.DecodePieces(tokens)
        return detokenized
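

# Illustrative roundtrip; "path/to/sp.model" is a hypothetical placeholder for
# a trained SentencePiece model file:
#
#     tokenizer = SentencePieceTokenizer("path/to/sp.model")
#     pieces = tokenizer.tokenize("The mixture was stirred for 2 h.")
#     # `pieces` is one string of space-separated subword tokens; detokenizing
#     # recovers the original sentence (up to SentencePiece normalization).
#     original = tokenizer.detokenize(pieces)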


class TranslatorWithSentencePiece:
    """Translator that tokenizes its inputs with SentencePiece and detokenizes
    the translations produced by the underlying OpenNMT model."""

    def __init__(
        self, translation_model: Union[str, Iterable[str]], sentencepiece_model: str
    ):
        self.sp = SentencePieceTokenizer(sentencepiece_model)
        self.translator = Translator.from_model_path(translation_model)

    def translate(self, sentences: List[str]) -> List[str]:
        """Translate sentences, returning the best translation for each."""
        translations = self.translate_multiple_with_scores(sentences)
        return [t[0].text for t in translations]

    def translate_multiple_with_scores(
        self, sentences: List[str], n_best: int = 1
    ) -> Iterator[List[TranslationResult]]:
        """Translate sentences, yielding for each input a list of its n_best
        detokenized translations (with scores)."""
        tokenized_sentences = [self.sp.tokenize(s) for s in sentences]

        translations = self.translator.translate_multiple_with_scores(
            tokenized_sentences, n_best
        )

        for translation_group in translations:
            for t in translation_group:
                t.text = self.sp.detokenize(t.text)
            yield translation_group
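

if __name__ == "__main__":
    # Minimal end-to-end sketch of how the pieces above fit together. The two
    # model paths are hypothetical placeholders; point them to your own trained
    # OpenNMT translation model and SentencePiece model.
    download_cde_data()

    text = "The mixture was stirred for 2 h. The product was filtered off."
    sentences = split_into_sentences(text)

    translator = TranslatorWithSentencePiece(
        translation_model="path/to/translation_model.pt",  # hypothetical path
        sentencepiece_model="path/to/sp.model",  # hypothetical path
    )
    for source, target in zip(sentences, translator.translate(sentences)):
        print(f"{source} -> {target}")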