# coding=utf-8
# Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for XLM."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import json
import logging
import os
import re
import sys
import unicodedata
from io import open

import sacremoses as sm

from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {
    'vocab_file': 'vocab.json',
    'merges_file': 'merges.txt',
}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
        'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.json",
        'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.json",
        'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.json",
        'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.json",
        'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
        'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
        'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
        'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
        'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-vocab.json",
    },
    'merges_file':
    {
        'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
        'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
        'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
        'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.txt",
        'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.txt",
        'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
        'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
        'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
        'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
        'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-merges.txt",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'xlm-mlm-en-2048': 512,
    'xlm-mlm-ende-1024': 512,
    'xlm-mlm-enfr-1024': 512,
    'xlm-mlm-enro-1024': 512,
    'xlm-mlm-tlm-xnli15-1024': 512,
    'xlm-mlm-xnli15-1024': 512,
    'xlm-clm-enfr-1024': 512,
    'xlm-clm-ende-1024': 512,
    'xlm-mlm-17-1280': 512,
    'xlm-mlm-100-1280': 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True},
    'xlm-mlm-ende-1024': {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {"0": "de", "1": "en"},
        "lang2id": {"de": 0, "en": 1}},
    'xlm-mlm-enfr-1024': {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {"0": "en", "1": "fr"},
        "lang2id": {"en": 0, "fr": 1}},
    'xlm-mlm-enro-1024': {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {"0": "en", "1": "ro"},
        "lang2id": {"en": 0, "ro": 1}},
    'xlm-mlm-tlm-xnli15-1024': {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {"0": "ar", "1": "bg", "2": "de", "3": "el", "4": "en",
                    "5": "es", "6": "fr", "7": "hi", "8": "ru", "9": "sw",
                    "10": "th", "11": "tr", "12": "ur", "13": "vi", "14": "zh"},
        "lang2id": {"ar": 0, "bg": 1, "de": 2, "el": 3, "en": 4,
                    "es": 5, "fr": 6, "hi": 7, "ru": 8, "sw": 9,
                    "th": 10, "tr": 11, "ur": 12, "vi": 13, "zh": 14}},
    'xlm-mlm-xnli15-1024': {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {"0": "ar", "1": "bg", "2": "de", "3": "el", "4": "en",
                    "5": "es", "6": "fr", "7": "hi", "8": "ru", "9": "sw",
                    "10": "th", "11": "tr", "12": "ur", "13": "vi", "14": "zh"},
        "lang2id": {"ar": 0, "bg": 1, "de": 2, "el": 3, "en": 4,
                    "es": 5, "fr": 6, "hi": 7, "ru": 8, "sw": 9,
                    "th": 10, "tr": 11, "ur": 12, "vi": 13, "zh": 14}},
    'xlm-clm-enfr-1024': {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {"0": "en", "1": "fr"},
        "lang2id": {"en": 0, "fr": 1}},
    'xlm-clm-ende-1024': {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {"0": "de", "1": "en"},
        "lang2id": {"de": 0, "en": 1}},
    'xlm-mlm-17-1280': {
        "do_lowercase_and_remove_accent": False,
        "id2lang": {"0": "ar", "1": "de", "2": "en", "3": "es", "4": "fr",
                    "5": "hi", "6": "it", "7": "ja", "8": "ko", "9": "nl",
                    "10": "pl", "11": "pt", "12": "ru", "13": "sv", "14": "tr",
                    "15": "vi", "16": "zh"},
        "lang2id": {"ar": 0, "de": 1, "en": 2, "es": 3, "fr": 4,
                    "hi": 5, "it": 6, "ja": 7, "ko": 8, "nl": 9,
                    "pl": 10, "pt": 11, "ru": 12, "sv": 13, "tr": 14,
                    "vi": 15, "zh": 16}},
    'xlm-mlm-100-1280': {
        "do_lowercase_and_remove_accent": False,
        "id2lang": {"0": "af", "1": "als", "2": "am", "3": "an", "4": "ang",
                    "5": "ar", "6": "arz", "7": "ast", "8": "az", "9": "bar",
                    "10": "be", "11": "bg", "12": "bn", "13": "br", "14": "bs",
                    "15": "ca", "16": "ceb", "17": "ckb", "18": "cs", "19": "cy",
                    "20": "da", "21": "de", "22": "el", "23": "en", "24": "eo",
                    "25": "es", "26": "et", "27": "eu", "28": "fa", "29": "fi",
                    "30": "fr", "31": "fy", "32": "ga", "33": "gan", "34": "gl",
                    "35": "gu", "36": "he", "37": "hi", "38": "hr", "39": "hu",
                    "40": "hy", "41": "ia", "42": "id", "43": "is", "44": "it",
                    "45": "ja", "46": "jv", "47": "ka", "48": "kk", "49": "kn",
                    "50": "ko", "51": "ku", "52": "la", "53": "lb", "54": "lt",
                    "55": "lv", "56": "mk", "57": "ml", "58": "mn", "59": "mr",
                    "60": "ms", "61": "my", "62": "nds", "63": "ne", "64": "nl",
                    "65": "nn", "66": "no", "67": "oc", "68": "pl", "69": "pt",
                    "70": "ro", "71": "ru", "72": "scn", "73": "sco", "74": "sh",
                    "75": "si", "76": "simple", "77": "sk", "78": "sl", "79": "sq",
                    "80": "sr", "81": "sv", "82": "sw", "83": "ta", "84": "te",
                    "85": "th", "86": "tl", "87": "tr", "88": "tt", "89": "uk",
                    "90": "ur", "91": "uz", "92": "vi", "93": "war", "94": "wuu",
                    "95": "yi", "96": "zh", "97": "zh_classical", "98": "zh_min_nan",
                    "99": "zh_yue"},
        "lang2id": {"af": 0, "als": 1, "am": 2, "an": 3, "ang": 4,
                    "ar": 5, "arz": 6, "ast": 7, "az": 8, "bar": 9,
                    "be": 10, "bg": 11, "bn": 12, "br": 13, "bs": 14,
                    "ca": 15, "ceb": 16, "ckb": 17, "cs": 18, "cy": 19,
                    "da": 20, "de": 21, "el": 22, "en": 23, "eo": 24,
                    "es": 25, "et": 26, "eu": 27, "fa": 28, "fi": 29,
                    "fr": 30, "fy": 31, "ga": 32, "gan": 33, "gl": 34,
                    "gu": 35, "he": 36, "hi": 37, "hr": 38, "hu": 39,
                    "hy": 40, "ia": 41, "id": 42, "is": 43, "it": 44,
                    "ja": 45, "jv": 46, "ka": 47, "kk": 48, "kn": 49,
                    "ko": 50, "ku": 51, "la": 52, "lb": 53, "lt": 54,
                    "lv": 55, "mk": 56, "ml": 57, "mn": 58, "mr": 59,
                    "ms": 60, "my": 61, "nds": 62, "ne": 63, "nl": 64,
                    "nn": 65, "no": 66, "oc": 67, "pl": 68, "pt": 69,
                    "ro": 70, "ru": 71, "scn": 72, "sco": 73, "sh": 74,
                    "si": 75, "simple": 76, "sk": 77, "sl": 78, "sq": 79,
                    "sr": 80, "sv": 81, "sw": 82, "ta": 83, "te": 84,
                    "th": 85, "tl": 86, "tr": 87, "tt": 88, "uk": 89,
                    "ur": 90, "uz": 91, "vi": 92, "war": 93, "wuu": 94,
                    "yi": 95, "zh": 96, "zh_classical": 97, "zh_min_nan": 98,
                    "zh_yue": 99}},
}


def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    word is represented as tuple of symbols (symbols being variable-length strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def lowercase_and_remove_accent(text):
    """
    Lowercase and strip accents from a piece of text, based on
    https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
    """
    text = ' '.join(text)
    text = text.lower()
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
        cat = unicodedata.category(char)
        if cat == "Mn":
            continue
        output.append(char)
    return "".join(output).lower().split(' ')


def replace_unicode_punct(text):
    '''
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
    '''
    text = text.replace('，', ',')
    text = re.sub(r'。\s*', '. ', text)
    text = text.replace('、', ',')
    text = text.replace('”', '"')
    text = text.replace('“', '"')
    text = text.replace('∶', ':')
    text = text.replace('：', ':')
    text = text.replace('？', '?')
    text = text.replace('《', '"')
    text = text.replace('》', '"')
    text = text.replace('）', ')')
    text = text.replace('！', '!')
    text = text.replace('（', '(')
    text = text.replace('；', ';')
    text = text.replace('１', '"')
    text = text.replace('」', '"')
    text = text.replace('「', '"')
    text = text.replace('０', '0')
    text = text.replace('３', '3')
    text = text.replace('２', '2')
    text = text.replace('５', '5')
    text = text.replace('６', '6')
    text = text.replace('９', '9')
    text = text.replace('７', '7')
    text = text.replace('８', '8')
    text = text.replace('４', '4')
    text = re.sub(r'．\s*', '. ', text)
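
# Illustrative sketch of the two helpers above (toy inputs, not tied to any
# pretrained vocabulary): `get_pairs` works on a tuple of symbols, while
# `lowercase_and_remove_accent` expects a list of already-split tokens.
#
#     get_pairs(("l", "o", "w"))
#     # -> {("l", "o"), ("o", "w")}
#
#     lowercase_and_remove_accent(["Héllo", "Wörld"])
#     # -> ["hello", "world"]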
    text = text.replace('～', '~')
    text = text.replace('’', '\'')
    text = text.replace('…', '...')
    text = text.replace('━', '-')
    text = text.replace('〈', '<')
    text = text.replace('〉', '>')
    text = text.replace('【', '[')
    text = text.replace('】', ']')
    text = text.replace('％', '%')
    return text


def remove_non_printing_char(text):
    '''
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
    '''
    output = []
    for char in text:
        cat = unicodedata.category(char)
        if cat.startswith('C'):
            continue
        output.append(char)
    return "".join(output)


def romanian_preprocessing(text):
    '''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`'''
    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
    text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
    text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
    text = text.replace("\u0218", "S").replace("\u0219", "s")  # s-comma
    text = text.replace("\u021a", "T").replace("\u021b", "t")  # t-comma
    text = text.replace("\u0102", "A").replace("\u0103", "a")
    text = text.replace("\u00C2", "A").replace("\u00E2", "a")
    text = text.replace("\u00CE", "I").replace("\u00EE", "i")
    return text


class XLMTokenizer(PreTrainedTokenizer):
    """
    BPE tokenizer for XLM

        - Moses preprocessing & tokenization for most supported languages

        - Language specific tokenization for Chinese (Jieba), Japanese (KyTea) and Thai (PyThaiNLP)

        - (optionally) lower case & normalize all input text

        - the argument ``special_tokens`` and the function ``set_special_tokens`` can be used to add additional symbols \
          (ex: "__classify__") to a vocabulary

        - the `lang2id` attribute maps the languages supported by the model to their ids if provided (automatically set for pretrained vocabularies)

        - the `id2lang` attribute does the reverse mapping if provided (automatically set for pretrained vocabularies)

        - `do_lowercase_and_remove_accent` controls lower casing and accent removal (automatically set for pretrained vocabularies)
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
                 sep_token="</s>", pad_token="<pad>", cls_token="</s>",
                 mask_token="<special1>", additional_special_tokens=["<special0>",
                 "<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
                 "<special6>", "<special7>", "<special8>", "<special9>"],
                 lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True,
                 **kwargs):
        super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
                                           sep_token=sep_token, pad_token=pad_token,
                                           cls_token=cls_token, mask_token=mask_token,
                                           additional_special_tokens=additional_special_tokens,
                                           **kwargs)

        # cache of sm.MosesPunctNormalizer instance
        self.cache_moses_punct_normalizer = dict()
        # cache of sm.MosesTokenizer instance
        self.cache_moses_tokenizer = dict()
        self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
        # True for currently supported models (v1.2.0), False for XLM-17 & 100
        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
        self.lang2id = lang2id
        self.id2lang = id2lang
        if lang2id is not None and id2lang is not None:
            assert len(lang2id) == len(id2lang)

        self.ja_word_tokenizer = None
        self.zh_word_tokenizer = None

        self.encoder = json.load(open(vocab_file, encoding="utf-8"))
        self.decoder = {v: k for k, v in self.encoder.items()}
        merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1]
        merges = [tuple(merge.split()[:2]) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}
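
    # Illustrative sketch: constructing the tokenizer directly from local BPE files
    # ('vocab.json' and 'merges.txt' are hypothetical paths; loading by checkpoint
    # name via `from_pretrained` is the usual entry point).
    #
    #     tokenizer = XLMTokenizer('vocab.json', 'merges.txt',
    #                              lang2id={"en": 0, "fr": 1},
    #                              id2lang={"0": "en", "1": "fr"})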

    def moses_punct_norm(self, text, lang):
        if lang not in self.cache_moses_punct_normalizer:
            punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
            self.cache_moses_punct_normalizer[lang] = punct_normalizer
        else:
            punct_normalizer = self.cache_moses_punct_normalizer[lang]
        return punct_normalizer.normalize(text)

    def moses_tokenize(self, text, lang):
        if lang not in self.cache_moses_tokenizer:
            moses_tokenizer = sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = moses_tokenizer
        else:
            moses_tokenizer = self.cache_moses_tokenizer[lang]
        return moses_tokenizer.tokenize(text, return_str=False, escape=False)

    def moses_pipeline(self, text, lang):
        text = replace_unicode_punct(text)
        text = self.moses_punct_norm(text, lang)
        text = remove_non_printing_char(text)
        return text

    def ja_tokenize(self, text):
        if self.ja_word_tokenizer is None:
            try:
                import Mykytea
                self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~'))
            except (AttributeError, ImportError) as e:
                logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and its Python wrapper (https://github.com/chezou/Mykytea-python) with the following steps")
                logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
                logger.error("2. autoreconf -i")
                logger.error("3. ./configure --prefix=$HOME/local")
                logger.error("4. make && make install")
                logger.error("5. pip install kytea")
                raise e
        return list(self.ja_word_tokenizer.getWS(text))

    @property
    def vocab_size(self):
        return len(self.encoder)

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        if word == '\n  </w>':
            word = '\n</w>'
        self.cache[token] = word
        return word
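
    # Illustrative sketch of `bpe`: the token gets a `</w>` end-of-word marker, then
    # merges are applied greedily by rank from `self.bpe_ranks`. With hypothetical
    # ranks containing no applicable merge, the token falls back to single characters:
    #
    #     self.bpe("low")
    #     # -> "l o w</w>"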

    def _tokenize(self, text, lang='en', bypass_tokenizer=False):
        """
        Tokenize a string given language code. For Chinese, Japanese and Thai, we use a language specific tokenizer. Otherwise, we use Moses.

        Details of tokenization:
            - [sacremoses](https://github.com/alvations/sacremoses): port of Moses
                - Install with `pip install sacremoses`
            - [pythainlp](https://github.com/PyThaiNLP/pythainlp): Thai tokenizer
                - Install with `pip install pythainlp`
            - [kytea](https://github.com/chezou/Mykytea-python): Japanese tokenizer, wrapper of [KyTea](https://github.com/neubig/kytea)
                - Install with the following steps:
                ```
                git clone git@github.com:neubig/kytea.git && cd kytea
                autoreconf -i
                ./configure --prefix=$HOME/local
                make && make install
                pip install kytea
                ```
            - [jieba](https://github.com/fxsjy/jieba): Chinese tokenizer *
                - Install with `pip install jieba`

        \* The original XLM used [Stanford Segmenter](https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip).
        However, the wrapper (`nltk.tokenize.stanford_segmenter`) is slow due to JVM overhead, and it will be deprecated.
        Jieba is a lot faster and pip-installable. Note there is some mismatch with the Stanford Segmenter. It should be fine
        if you fine-tune the model with Chinese supervision. If you want the exact same behaviour, use the original XLM
        [preprocessing script](https://github.com/facebookresearch/XLM/tree/master/tools) to tokenize the sentence externally,
        and set `bypass_tokenizer=True` to bypass the tokenizer.

        Args:
            - lang: ISO language code (default = 'en') (string). Languages should belong to the model's supported languages. However, we don't enforce it.
            - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False) (bool). If True, we only apply BPE.

        Returns:
            List of tokens.
        """
        if lang and self.lang2id and lang not in self.lang2id:
            logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.")
        if bypass_tokenizer:
            text = text.split()
        elif lang not in self.lang_with_custom_tokenizer:
            text = self.moses_pipeline(text, lang=lang)
            # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
            if lang == 'ro':
                text = romanian_preprocessing(text)
            text = self.moses_tokenize(text, lang=lang)
        elif lang == 'th':
            text = self.moses_pipeline(text, lang=lang)
            try:
                if 'pythainlp' not in sys.modules:
                    from pythainlp.tokenize import word_tokenize as th_word_tokenize
                else:
                    th_word_tokenize = sys.modules['pythainlp'].word_tokenize
            except (AttributeError, ImportError) as e:
                logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps")
                logger.error("1. pip install pythainlp")
                raise e
            text = th_word_tokenize(text)
        elif lang == 'zh':
            try:
                if 'jieba' not in sys.modules:
                    import jieba
                else:
                    jieba = sys.modules['jieba']
            except (AttributeError, ImportError) as e:
                logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
                logger.error("1. pip install jieba")
                raise e
            text = ' '.join(jieba.cut(text))
            text = self.moses_pipeline(text, lang=lang)
            text = text.split()
        elif lang == 'ja':
            text = self.moses_pipeline(text, lang=lang)
            text = self.ja_tokenize(text)
        else:
            raise ValueError('It should not reach here')

        if self.do_lowercase_and_remove_accent and not bypass_tokenizer:
            text = lowercase_and_remove_accent(text)

        split_tokens = []
        for token in text:
            if token:
                split_tokens.extend([t for t in self.bpe(token).split(' ')])

        return split_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) to an id using the vocab. """
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (string/unicode) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) into a single string. """
        out_string = ''.join(tokens).replace('</w>', ' ').strip()
        return out_string

    def add_special_tokens_single_sentence(self, token_ids):
        """
        Adds special tokens to a sequence for sequence classification tasks.
        An XLM sequence has the following format: [CLS] X [SEP]
        """
        return [self.cls_token_id] + token_ids + [self.sep_token_id]

    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
        """
        Adds special tokens to a sequence pair for sequence classification tasks.
        An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep
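
    # Illustrative sketch: with hypothetical ids cls_token_id == 0 and sep_token_id == 1,
    # the two helpers above produce
    #
    #     add_special_tokens_single_sentence([5, 6])           # -> [0, 5, 6, 1]
    #     add_special_tokens_sentences_pair([5, 6], [7, 8])    # -> [0, 5, 6, 1, 7, 8, 1]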

    def save_vocabulary(self, save_directory):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])

        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1

        return vocab_file, merge_file
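
# Illustrative usage sketch (assumes one of the pretrained checkpoints listed at the
# top of this file; `from_pretrained` is provided by the `PreTrainedTokenizer` base class):
#
#     tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-enfr-1024')
#     tokens = tokenizer.tokenize("Hello world!", lang='en')
#     ids = tokenizer.convert_tokens_to_ids(tokens)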