# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines Subtokenizer class to encode and decode strings."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import sys
import unicodedata

from absl import logging
import numpy as np
import six
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf, tf_keras

# pylint: disable=g-complex-comprehension

PAD = "<pad>"
PAD_ID = 0
EOS = "<EOS>"
EOS_ID = 1
RESERVED_TOKENS = [PAD, EOS]

# Set of characters that will be used in the function _escape_token() (see func
# docstring for more details).
# This set is added to the alphabet list to ensure that all escaped tokens can
# be encoded.
_ESCAPE_CHARS = set(u"\\_u;0123456789")
# Regex for the function _unescape_token(), the inverse of _escape_token().
# This is used to find "\u", "\\", and "\###;" substrings in the token.
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")

_UNDEFINED_UNICODE = u"\u3013"


def alphanumeric_char_set():
  return set(
      six.unichr(i)
      for i in xrange(sys.maxunicode)
      if (unicodedata.category(six.unichr(i)).startswith("L") or
          unicodedata.category(six.unichr(i)).startswith("N")))


# Set contains all letter and number characters.
_ALPHANUMERIC_CHAR_SET = alphanumeric_char_set()

# min_count is the minimum number of times a subtoken must appear in the data
# before it is added to the vocabulary. The value is found using binary search
# to obtain the target vocabulary size.
_MIN_MIN_COUNT = 1  # min value to use when binary searching for min_count
_MAX_MIN_COUNT = 1000  # max value to use when binary searching for min_count


class Subtokenizer(object):
  """Encodes and decodes strings to/from integer IDs."""

  def __init__(self, vocab_file, reserved_tokens=None, master_char_set=None):
    """Initializes class from an existing subtoken vocabulary file."""
    logging.info("Initializing Subtokenizer from file %s.", vocab_file)

    if master_char_set is None:
      master_char_set = _ALPHANUMERIC_CHAR_SET

    if reserved_tokens is None:
      reserved_tokens = RESERVED_TOKENS

    self.subtoken_list = _load_vocab_file(vocab_file, reserved_tokens)
    self.alphabet = _generate_alphabet_dict(self.subtoken_list)
    self.subtoken_to_id_dict = _list_to_index_dict(self.subtoken_list)

    self.max_subtoken_length = 0
    for subtoken in self.subtoken_list:
      self.max_subtoken_length = max(self.max_subtoken_length, len(subtoken))

    # Create cache to speed up subtokenization
    self._cache_size = 2**20
    self._cache = [(None, None)] * self._cache_size
    self._master_char_set = master_char_set
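
  # A minimal usage sketch (added commentary, not part of the original module).
  # Assuming "vocab.subtokens" is a vocabulary file previously written by
  # `init_from_files`, encoding and decoding should round-trip:
  #
  #   subtokenizer = Subtokenizer("vocab.subtokens")
  #   ids = subtokenizer.encode("Hello world", add_eos=True)
  #   text = subtokenizer.decode(ids[:-1])  # drop EOS_ID before decoding
  #   # text == "Hello world"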

  @staticmethod
  def init_from_files(vocab_file,
                      files,
                      target_vocab_size,
                      threshold,
                      min_count=None,
                      file_byte_limit=1e6,
                      reserved_tokens=None,
                      correct_strip=True,
                      master_char_set=None):
    """Create subtoken vocabulary based on files, and save vocab to file.

    Args:
      vocab_file: String name of vocab file to store subtoken vocabulary.
      files: List of file paths that will be used to generate vocabulary.
      target_vocab_size: target vocabulary size to generate.
      threshold: int threshold of vocabulary size to accept.
      min_count: int minimum count to use for generating the vocabulary. The
        min count is the minimum number of times a subtoken should appear in
        the files before it is added to the vocabulary. If set to None, this
        value is found using binary search.
      file_byte_limit: (Default 1e6) Maximum number of bytes of sample text
        that will be drawn from the files.
      reserved_tokens: List of string tokens that are guaranteed to be at the
        beginning of the subtoken vocabulary list.
      correct_strip: Whether to convert text to unicode before stripping.
      master_char_set: the set of characters used when splitting text into
        tokens (defaults to all letters and numbers).

    Returns:
      Subtokenizer object
    """
    if master_char_set is None:
      master_char_set = _ALPHANUMERIC_CHAR_SET
    if reserved_tokens is None:
      reserved_tokens = RESERVED_TOKENS

    if tf.io.gfile.exists(vocab_file):
      logging.info("Vocab file already exists (%s)", vocab_file)
    else:
      logging.info("Begin steps to create subtoken vocabulary...")
      token_counts = _count_tokens(files, file_byte_limit, correct_strip,
                                   master_char_set)
      alphabet = _generate_alphabet_dict(token_counts)
      subtoken_list = _generate_subtokens_with_target_vocab_size(
          token_counts, alphabet, target_vocab_size, threshold, min_count,
          reserved_tokens)
      logging.info("Generated vocabulary with %d subtokens.",
                   len(subtoken_list))
      _save_vocab_file(vocab_file, subtoken_list)
    return Subtokenizer(vocab_file, master_char_set=master_char_set)

  def encode(self, raw_string, add_eos=False):
    """Encodes a string into a list of int subtoken ids."""
    ret = []
    tokens = _split_string_to_tokens(
        native_to_unicode(raw_string), self._master_char_set)
    for token in tokens:
      ret.extend(self._token_to_subtoken_ids(token))
    if add_eos:
      assert EOS in self.subtoken_list, \
          "Can't append 'EOS' because it is not in list of known subtokens."
      ret.append(EOS_ID)
    return ret

  def _token_to_subtoken_ids(self, token):
    """Encode a single token into a list of subtoken ids."""
    cache_location = hash(token) % self._cache_size
    cache_key, cache_value = self._cache[cache_location]
    if cache_key == token:
      return cache_value

    ret = _split_token_to_subtokens(
        _escape_token(token, self.alphabet), self.subtoken_to_id_dict,
        self.max_subtoken_length)
    ret = [self.subtoken_to_id_dict[subtoken_id] for subtoken_id in ret]

    self._cache[cache_location] = (token, ret)
    return ret

  def decode(self, subtokens):
    """Converts list of int subtoken ids into a string."""
    if isinstance(subtokens, np.ndarray):
      # Note that list(subtokens) converts subtokens to a python list, but the
      # items remain as np.int32. This converts both the array and its items.
      subtokens = subtokens.tolist()

    if not subtokens:
      return ""

    assert isinstance(subtokens, list) and isinstance(subtokens[0], int), (
        "Subtokens argument passed into decode() must be a list of integers.")

    return _unicode_to_native(
        _join_tokens_to_string(
            self._subtoken_ids_to_tokens(subtokens), self._master_char_set))

  def _subtoken_ids_to_tokens(self, subtokens):
    """Convert list of int subtoken ids to a list of string tokens."""
    escaped_tokens = "".join([
        self.subtoken_list[s] for s in subtokens if s < len(self.subtoken_list)
    ])
    escaped_tokens = escaped_tokens.split("_")

    # All tokens in the vocabulary list have been escaped (see _escape_token())
    # so each token must be unescaped when decoding.
    ret = []
    for token in escaped_tokens:
      if token:
        ret.append(_unescape_token(token))
    return ret
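

# Added commentary (not part of the original module): decoding reverses the
# encoding pipeline. For a hypothetical vocabulary where ids [5, 9] map to the
# escaped subtokens u"Hell" and u"o_", _subtoken_ids_to_tokens([5, 9]) joins
# them into u"Hello_", splits on the trailing "_" end-of-token marker, and
# unescapes the piece, returning [u"Hello"]; _join_tokens_to_string then
# re-inserts a single space only between adjacent alphanumeric tokens.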


def _save_vocab_file(vocab_file, subtoken_list):
  """Save subtokens to file."""
  with tf.io.gfile.GFile(vocab_file, mode="w") as f:
    for subtoken in subtoken_list:
      f.write("'%s'\n" % _unicode_to_native(subtoken))


def _load_vocab_file(vocab_file, reserved_tokens=None):
  """Load vocabulary while ensuring reserved tokens are at the top."""
  if reserved_tokens is None:
    reserved_tokens = RESERVED_TOKENS

  subtoken_list = []
  with tf.io.gfile.GFile(vocab_file, mode="r") as f:
    for line in f:
      subtoken = native_to_unicode(line.strip())
      subtoken = subtoken[1:-1]  # Remove surrounding single-quotes
      if subtoken in reserved_tokens:
        continue
      subtoken_list.append(native_to_unicode(subtoken))
  return reserved_tokens + subtoken_list


def native_to_unicode(s):
  """Convert string to unicode (required in Python 2)."""
  try:               # Python 2
    return s if isinstance(s, unicode) else s.decode("utf-8")
  except NameError:  # Python 3
    return s


def _unicode_to_native(s):
  """Convert string from unicode to native format (required in Python 2)."""
  try:               # Python 2
    return s.encode("utf-8") if isinstance(s, unicode) else s
  except NameError:  # Python 3
    return s


def _split_string_to_tokens(text, master_char_set):
  """Splits text to a list of string tokens."""
  if not text:
    return []
  ret = []
  token_start = 0
  # Classify each character in the input string
  is_master = [c in master_char_set for c in text]
  for pos in xrange(1, len(text)):
    if is_master[pos] != is_master[pos - 1]:
      token = text[token_start:pos]
      if token != u" " or token_start == 0:
        ret.append(token)
      token_start = pos
  final_token = text[token_start:]
  ret.append(final_token)
  return ret


def _join_tokens_to_string(tokens, master_char_set):
  """Join a list of string tokens into a single string."""
  token_is_master = [t[0] in master_char_set for t in tokens]
  ret = []
  for i, token in enumerate(tokens):
    if i > 0 and token_is_master[i - 1] and token_is_master[i]:
      ret.append(u" ")
    ret.append(token)
  return "".join(ret)


def _escape_token(token, alphabet):
  r"""Replace characters that aren't in the alphabet and append "_" to token.

  Apply three transformations to the token:
    1. Replace underline character "_" with "\u", and backslash "\" with "\\".
    2. Replace characters outside of the alphabet with "\###;", where ### is
       the character's Unicode code point.
    3. Append "_" to mark the end of a token.

  Args:
    token: unicode string to be escaped
    alphabet: list of all known characters

  Returns:
    escaped string
  """
  token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
  ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
  return u"".join(ret) + "_"
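

# Added commentary (not part of the original module): a small worked example
# of the escaping scheme, assuming every character of the token is in
# `alphabet`:
#
#   _escape_token(u"foo_bar", alphabet)  ->  u"foo\\ubar_"
#
# The trailing "_" marks the token boundary; it is stripped by the split("_")
# in _subtoken_ids_to_tokens, after which _unescape_token(u"foo\\ubar")
# restores u"foo_bar".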


def _unescape_token(token):
  r"""Replaces escaped characters in the token with their unescaped versions.

  Applies inverse transformations as _escape_token():
    1. Replace "\u" with "_", and "\\" with "\".
    2. Replace "\###;" with the unicode character the ### refers to.

  Args:
    token: escaped string

  Returns:
    unescaped string
  """

  def match(m):
    r"""Returns replacement string for matched object.

    Matched objects contain one of the strings that matches the regex pattern:
      r"\\u|\\\\|\\([0-9]+);"
    The strings can be '\u', '\\', or '\###;' (### is any digit number).

    m.group(0) refers to the entire matched string ('\u', '\\', or '\###;').
    m.group(1) refers to the first parenthesized subgroup ('###').

    m.group(0) exists for all match objects, while m.group(1) exists only for
    the string '\###;'.

    This function looks to see if m.group(1) exists. If it doesn't, then the
    matched string must be '\u' or '\\', and the corresponding replacement
    ('_' or '\') is returned. Note that in python, a single backslash is
    written as '\\', and double backslash as '\\\\'.

    If m.group(1) exists, then use the integer in m.group(1) to return a
    unicode character.

    Args:
      m: match object

    Returns:
      String to replace matched object with.
    """
    # Check if the matched strings are '\u' or '\\'.
    if m.group(1) is None:
      return u"_" if m.group(0) == u"\\u" else u"\\"

    # If m.group(1) exists, try to return the corresponding unicode character.
    try:
      return six.unichr(int(m.group(1)))
    except (ValueError, OverflowError) as _:
      return _UNDEFINED_UNICODE

  # Use match function to replace escaped substrings in the token.
  return _UNESCAPE_REGEX.sub(match, token)


def _count_tokens(files,
                  file_byte_limit=1e6,
                  correct_strip=True,
                  master_char_set=None):
  """Return token counts of words in the files.

  Samples file_byte_limit bytes from each file, and counts the words that
  appear in the samples. The samples are semi-evenly distributed across the
  file.

  Args:
    files: List of filepaths
    file_byte_limit: Max number of bytes that will be read from each file.
    correct_strip: Whether to convert text to unicode before stripping. This
      affects vocabulary generation for PY2. Set correct_strip to False in PY2
      to reproduce the previous common public result. Setting correct_strip to
      True lets PY2 and PY3 produce a consistent vocabulary.
    master_char_set: the set of characters used when splitting text into
      tokens (defaults to all letters and numbers).

  Returns:
    Dictionary mapping tokens to the number of times they appear in the
    sampled lines from the files.
  """
  if master_char_set is None:
    master_char_set = _ALPHANUMERIC_CHAR_SET

  token_counts = collections.defaultdict(int)

  for filepath in files:
    with tf.io.gfile.GFile(filepath, mode="r") as reader:
      file_byte_budget = file_byte_limit
      counter = 0
      lines_to_skip = int(reader.size() / (file_byte_budget * 2))
      for line in reader:
        if counter < lines_to_skip:
          counter += 1
        else:
          if file_byte_budget < 0:
            break
          if correct_strip:
            line = native_to_unicode(line)
          line = line.strip()
          file_byte_budget -= len(line)
          counter = 0

          # Add words to token counts
          for token in _split_string_to_tokens(
              native_to_unicode(line), master_char_set):
            token_counts[token] += 1
  return token_counts
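

# Added commentary (not part of the original module): the sampling above keeps
# roughly file_byte_limit bytes per file. For a hypothetical 200 MB file with
# the default file_byte_limit of 1e6, lines_to_skip is about 100, so roughly
# one line in every 101 is counted until the byte budget runs out.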


def _list_to_index_dict(lst):
  """Create dictionary mapping list items to their indices in the list."""
  return {item: n for n, item in enumerate(lst)}


def _split_token_to_subtokens(token, subtoken_dict, max_subtoken_length):
  """Splits a token into subtokens defined in the subtoken dict."""
  ret = []
  start = 0
  token_len = len(token)
  while start < token_len:
    # Find the longest subtoken, so iterate backwards.
    for end in xrange(min(token_len, start + max_subtoken_length), start, -1):
      subtoken = token[start:end]
      if subtoken in subtoken_dict:
        ret.append(subtoken)
        start = end
        break
    else:  # Did not break
      # If there is no possible encoding of the escaped token then one of the
      # characters in the token is not in the alphabet. This should be
      # impossible and would be indicative of a bug.
      raise ValueError("Was unable to split token \"%s\" into subtokens." %
                       token)
  return ret


def _generate_subtokens_with_target_vocab_size(token_counts,
                                               alphabet,
                                               target_size,
                                               threshold,
                                               min_count=None,
                                               reserved_tokens=None):
  """Generate subtoken vocabulary close to the target size."""
  if reserved_tokens is None:
    reserved_tokens = RESERVED_TOKENS

  if min_count is not None:
    logging.info("Using min_count=%d to generate vocab with target size %d",
                 min_count, target_size)
    return _generate_subtokens(
        token_counts, alphabet, min_count, reserved_tokens=reserved_tokens)

  def bisect(min_val, max_val):
    """Recursive function to binary search for subtoken vocabulary."""
    cur_count = (min_val + max_val) // 2
    logging.info("Binary search: trying min_count=%d (%d %d)", cur_count,
                 min_val, max_val)
    subtoken_list = _generate_subtokens(
        token_counts, alphabet, cur_count, reserved_tokens=reserved_tokens)

    val = len(subtoken_list)
    logging.info("Binary search: min_count=%d resulted in %d tokens",
                 cur_count, val)

    within_threshold = abs(val - target_size) < threshold
    if within_threshold or min_val >= max_val or cur_count < 2:
      return subtoken_list
    if val > target_size:
      other_subtoken_list = bisect(cur_count + 1, max_val)
    else:
      other_subtoken_list = bisect(min_val, cur_count - 1)

    # Return the vocabulary list with the closest number of tokens.
    other_val = len(other_subtoken_list)
    if abs(other_val - target_size) < abs(val - target_size):
      return other_subtoken_list
    return subtoken_list

  logging.info("Finding best min_count to get target size of %d", target_size)
  return bisect(_MIN_MIN_COUNT, _MAX_MIN_COUNT)
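

# Added commentary (not part of the original module): the binary search above
# relies on the vocabulary size shrinking (roughly monotonically) as min_count
# grows. For example, if min_count=500 yields more subtokens than target_size,
# the search continues in [501, 1000]; otherwise it continues in [1, 499],
# keeping whichever candidate vocabulary lands closest to target_size.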


def _generate_alphabet_dict(iterable, reserved_tokens=None):
  """Create set of characters that appear in any element in the iterable."""
  if reserved_tokens is None:
    reserved_tokens = RESERVED_TOKENS
  alphabet = {c for token in iterable for c in token}
  alphabet |= {c for token in reserved_tokens for c in token}
  alphabet |= _ESCAPE_CHARS  # Add escape characters to alphabet set.
  return alphabet


def _count_and_gen_subtokens(token_counts, alphabet, subtoken_dict,
                             max_subtoken_length):
  """Count number of times subtokens appear, and generate new subtokens.

  Args:
    token_counts: dict mapping tokens to the number of times they appear in
      the original files.
    alphabet: list of allowed characters. Used to escape the tokens, which
      guarantees that all tokens can be split into subtokens.
    subtoken_dict: dict mapping subtokens to ids.
    max_subtoken_length: maximum length of subtoken in subtoken_dict.

  Returns:
    A defaultdict mapping subtokens to the number of times they appear in the
    tokens. The dict may contain new subtokens.
  """
  subtoken_counts = collections.defaultdict(int)
  for token, count in six.iteritems(token_counts):
    token = _escape_token(token, alphabet)
    subtokens = _split_token_to_subtokens(token, subtoken_dict,
                                          max_subtoken_length)

    # Generate new subtokens by taking substrings from token.
    start = 0
    for subtoken in subtokens:
      for end in xrange(start + 1, len(token) + 1):
        new_subtoken = token[start:end]
        subtoken_counts[new_subtoken] += count
      start += len(subtoken)

  return subtoken_counts


def _filter_and_bucket_subtokens(subtoken_counts, min_count):
  """Return a bucketed list of subtokens that are filtered by count.

  Args:
    subtoken_counts: defaultdict mapping subtokens to their counts
    min_count: int count used to filter subtokens

  Returns:
    List of subtoken sets, where subtokens in set i have the same length=i.
  """
  # Create list of buckets, where subtokens in bucket i have length i.
  subtoken_buckets = []
  for subtoken, count in six.iteritems(subtoken_counts):
    if count < min_count:  # Filter out subtokens that don't appear enough
      continue
    while len(subtoken_buckets) <= len(subtoken):
      subtoken_buckets.append(set())
    subtoken_buckets[len(subtoken)].add(subtoken)
  return subtoken_buckets
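

# Added commentary (not part of the original module): a small hypothetical
# example of the bucketing above. With subtoken_counts {u"a": 10, u"ab": 6,
# u"abc": 3} and min_count=5, _filter_and_bucket_subtokens returns
# [set(), {u"a"}, {u"ab"}]: u"abc" is filtered out, and bucket i holds the
# surviving subtokens of length i.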


def _gen_new_subtoken_list(subtoken_counts,
                           min_count,
                           alphabet,
                           reserved_tokens=None):
  """Generate candidate subtokens ordered by count, and new max subtoken length.

  Add subtokens to the candidate list in order of length (longest subtokens
  first). When a subtoken is added, the counts of each of its prefixes are
  decreased. Prefixes that don't appear much outside the subtoken are not
  added to the candidate list.

  For example:
    subtoken being added to candidate list: 'translate'
    subtoken_counts: {'translate':10, 't':40, 'tr':16, 'tra':12, ...}
    min_count: 5

  When 'translate' is added, subtoken_counts is updated to:
    {'translate':0, 't':30, 'tr':6, 'tra': 2, ...}

  The subtoken 'tra' will not be added to the candidate list, because it
  appears twice (less than min_count) outside of 'translate'.

  Args:
    subtoken_counts: defaultdict mapping str subtokens to int counts
    min_count: int minimum count requirement for subtokens
    alphabet: set of characters. Each character is added to the subtoken list
      to guarantee that all tokens can be encoded.
    reserved_tokens: list of tokens that will be added to the beginning of the
      returned subtoken list.

  Returns:
    List of candidate subtokens in decreasing count order, and maximum subtoken
    length
  """
  if reserved_tokens is None:
    reserved_tokens = RESERVED_TOKENS

  # Create a list of (count, subtoken) for each candidate subtoken.
  subtoken_candidates = []

  # Use bucketed list to iterate through subtokens in order of length.
  # subtoken_buckets[i] = set(subtokens), where each subtoken has length i.
  subtoken_buckets = _filter_and_bucket_subtokens(subtoken_counts, min_count)
  max_subtoken_length = len(subtoken_buckets) - 1

  # Go through the list in reverse order to consider longer subtokens first.
  for subtoken_len in xrange(max_subtoken_length, 0, -1):
    for subtoken in subtoken_buckets[subtoken_len]:
      count = subtoken_counts[subtoken]

      # Possible if this subtoken is a prefix of another token.
      if count < min_count:
        continue

      # Ignore alphabet/reserved tokens, which will be added manually later.
      if subtoken not in alphabet and subtoken not in reserved_tokens:
        subtoken_candidates.append((count, subtoken))

      # Decrement count of the subtoken's prefixes (if a longer subtoken is
      # added, its prefixes lose priority to be added).
      for end in xrange(1, subtoken_len):
        subtoken_counts[subtoken[:end]] -= count

  # Add alphabet subtokens (guarantees that all strings are encodable).
  subtoken_candidates.extend((subtoken_counts.get(a, 0), a) for a in alphabet)

  # Order subtoken candidates by decreasing count.
  subtoken_list = [t for _, t in sorted(subtoken_candidates, reverse=True)]

  # Add reserved tokens to beginning of the list.
  subtoken_list = reserved_tokens + subtoken_list
  return subtoken_list, max_subtoken_length


def _generate_subtokens(token_counts,
                        alphabet,
                        min_count,
                        num_iterations=4,
                        reserved_tokens=None):
  """Create a list of subtokens in decreasing order of frequency.

  Args:
    token_counts: dict mapping str tokens -> int count
    alphabet: set of characters
    min_count: int minimum number of times a subtoken must appear before it is
      added to the vocabulary.
    num_iterations: int number of iterations to generate new tokens.
    reserved_tokens: list of tokens that will be added to the beginning of the
      returned subtoken list.

  Returns:
    Sorted list of subtokens (most frequent first)
  """
  if reserved_tokens is None:
    reserved_tokens = RESERVED_TOKENS

  # Use alphabet set to create initial list of subtokens
  subtoken_list = reserved_tokens + list(alphabet)
  max_subtoken_length = 1

  # On each iteration, segment all words using the subtokens defined in
  # subtoken_dict, count how often the resulting subtokens appear, and update
  # the dictionary with subtokens w/ high enough counts.
  for i in xrange(num_iterations):
    logging.info("\tGenerating subtokens: iteration %d", i)
    # Generate new subtoken->id dictionary using the new subtoken list.
    subtoken_dict = _list_to_index_dict(subtoken_list)

    # Create dict mapping subtoken->count, with additional subtokens created
    # from substrings taken from the tokens.
    subtoken_counts = _count_and_gen_subtokens(token_counts, alphabet,
                                               subtoken_dict,
                                               max_subtoken_length)

    # Generate new list of subtokens sorted by subtoken count.
    subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
        subtoken_counts, min_count, alphabet, reserved_tokens)

    logging.info("\tVocab size: %d", len(subtoken_list))
  return subtoken_list
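

# Added commentary (not part of the original module): a hedged sketch of how a
# vocabulary might be generated and reused; the file names and parameter
# values below are hypothetical, not taken from this module.
#
#   subtokenizer = Subtokenizer.init_from_files(
#       "vocab.ende.32768", ["train.en", "train.de"],
#       target_vocab_size=32768, threshold=327)
#   ids = subtokenizer.encode("A sample sentence.", add_eos=True)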