# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines Subtokenizer class to encode and decode strings.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import collections | |
import re | |
import sys | |
import unicodedata | |
from absl import logging | |
import numpy as np | |
import six | |
from six.moves import xrange # pylint: disable=redefined-builtin | |
import tensorflow as tf, tf_keras | |
# pylint: disable=g-complex-comprehension | |
PAD = "<pad>" | |
PAD_ID = 0 | |
EOS = "<EOS>" | |
EOS_ID = 1 | |
RESERVED_TOKENS = [PAD, EOS] | |
# Set of characters that will be used in the function _escape_token() (see func | |
# docstring for more details). | |
# This set is added to the alphabet list to ensure that all escaped tokens can | |
# be encoded. | |
_ESCAPE_CHARS = set(u"\\_u;0123456789") | |
# Regex for the function _unescape_token(), the inverse of _escape_token(). | |
# This is used to find "\u", "\\", and "\###;" substrings in the token. | |
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);") | |
_UNDEFINED_UNICODE = u"\u3013" | |


def alphanumeric_char_set():
  return set(
      six.unichr(i)
      for i in xrange(sys.maxunicode)
      if (unicodedata.category(six.unichr(i)).startswith("L") or
          unicodedata.category(six.unichr(i)).startswith("N")))


# Set contains all letter and number characters.
_ALPHANUMERIC_CHAR_SET = alphanumeric_char_set()

# min_count is the minimum number of times a subtoken must appear in the data
# before it is added to the vocabulary. The value is found using binary
# search to obtain the target vocabulary size.
_MIN_MIN_COUNT = 1  # min value to use when binary searching for min_count
_MAX_MIN_COUNT = 1000  # max value to use when binary searching for min_count


class Subtokenizer(object):
  """Encodes and decodes strings to/from integer IDs."""

  def __init__(self, vocab_file, reserved_tokens=None, master_char_set=None):
    """Initializes the Subtokenizer from an existing vocab file."""
    logging.info("Initializing Subtokenizer from file %s.", vocab_file)

    if master_char_set is None:
      master_char_set = _ALPHANUMERIC_CHAR_SET
    if reserved_tokens is None:
      reserved_tokens = RESERVED_TOKENS

    self.subtoken_list = _load_vocab_file(vocab_file, reserved_tokens)
    self.alphabet = _generate_alphabet_dict(self.subtoken_list)
    self.subtoken_to_id_dict = _list_to_index_dict(self.subtoken_list)
    self.max_subtoken_length = 0
    for subtoken in self.subtoken_list:
      self.max_subtoken_length = max(self.max_subtoken_length, len(subtoken))

    # Create cache to speed up subtokenization
    self._cache_size = 2**20
    self._cache = [(None, None)] * self._cache_size
    self._master_char_set = master_char_set

  @staticmethod
  def init_from_files(vocab_file,
                      files,
                      target_vocab_size,
                      threshold,
                      min_count=None,
                      file_byte_limit=1e6,
                      reserved_tokens=None,
                      correct_strip=True,
                      master_char_set=None):
    """Create subtoken vocabulary based on files, and save vocab to file.

    Args:
      vocab_file: String name of vocab file to store subtoken vocabulary.
      files: List of file paths that will be used to generate vocabulary.
      target_vocab_size: target vocabulary size to generate.
      threshold: int threshold of vocabulary size to accept.
      min_count: int minimum count to use for generating the vocabulary. The min
        count is the minimum number of times a subtoken should appear in the
        files before it is added to the vocabulary. If set to None, this value
        is found using binary search.
      file_byte_limit: (Default 1e6) Maximum number of bytes of sample text that
        will be drawn from the files.
      reserved_tokens: List of string tokens that are guaranteed to be at the
        beginning of the subtoken vocabulary list.
      correct_strip: Whether to convert text to unicode before strip.
      master_char_set: the char set used to split text into tokens.

    Returns:
      Subtokenizer object
    """
    if master_char_set is None:
      master_char_set = _ALPHANUMERIC_CHAR_SET
    if reserved_tokens is None:
      reserved_tokens = RESERVED_TOKENS

    if tf.io.gfile.exists(vocab_file):
      logging.info("Vocab file already exists (%s)", vocab_file)
    else:
      logging.info("Begin steps to create subtoken vocabulary...")
      token_counts = _count_tokens(files, file_byte_limit, correct_strip,
                                   master_char_set)
      alphabet = _generate_alphabet_dict(token_counts)
      subtoken_list = _generate_subtokens_with_target_vocab_size(
          token_counts, alphabet, target_vocab_size, threshold, min_count,
          reserved_tokens)
      logging.info("Generated vocabulary with %d subtokens.",
                   len(subtoken_list))
      _save_vocab_file(vocab_file, subtoken_list)
    return Subtokenizer(vocab_file, master_char_set=master_char_set)

  def encode(self, raw_string, add_eos=False):
    """Encodes a string into a list of int subtoken ids."""
    ret = []
    tokens = _split_string_to_tokens(
        native_to_unicode(raw_string), self._master_char_set)
    for token in tokens:
      ret.extend(self._token_to_subtoken_ids(token))
    if add_eos:
      assert EOS in self.subtoken_list, \
          "Can't append 'EOS' because it is not in list of known subtokens."
      ret.append(EOS_ID)
    return ret

  def _token_to_subtoken_ids(self, token):
    """Encode a single token into a list of subtoken ids."""
    cache_location = hash(token) % self._cache_size
    cache_key, cache_value = self._cache[cache_location]
    if cache_key == token:
      return cache_value

    ret = _split_token_to_subtokens(
        _escape_token(token, self.alphabet), self.subtoken_to_id_dict,
        self.max_subtoken_length)
    # Map each subtoken string to its integer id.
    ret = [self.subtoken_to_id_dict[subtoken] for subtoken in ret]

    self._cache[cache_location] = (token, ret)
    return ret

  def decode(self, subtokens):
    """Converts a list of int subtoken ids into a string."""
    if isinstance(subtokens, np.ndarray):
      # Note that list(subtokens) converts subtokens to a python list, but the
      # items remain as np.int32. This converts both the array and its items.
      subtokens = subtokens.tolist()

    if not subtokens:
      return ""

    assert isinstance(subtokens, list) and isinstance(subtokens[0], int), (
        "Subtokens argument passed into decode() must be a list of integers.")

    return _unicode_to_native(
        _join_tokens_to_string(
            self._subtoken_ids_to_tokens(subtokens), self._master_char_set))

  def _subtoken_ids_to_tokens(self, subtokens):
    """Convert list of int subtoken ids to a list of string tokens."""
    escaped_tokens = "".join([
        self.subtoken_list[s] for s in subtokens if s < len(self.subtoken_list)
    ])
    escaped_tokens = escaped_tokens.split("_")

    # All tokens in the vocabulary list have been escaped (see _escape_token())
    # so each token must be unescaped when decoding.
    ret = []
    for token in escaped_tokens:
      if token:
        ret.append(_unescape_token(token))
    return ret
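

# Example usage (a sketch; assumes "vocab.ende.32768" is an existing subtoken
# vocab file, e.g. one produced by Subtokenizer.init_from_files):
#
#   subtokenizer = Subtokenizer("vocab.ende.32768")
#   ids = subtokenizer.encode(u"Hello world", add_eos=True)  # [..., EOS_ID]
#   text = subtokenizer.decode(ids)                          # u"Hello world"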


def _save_vocab_file(vocab_file, subtoken_list):
  """Save subtokens to file."""
  with tf.io.gfile.GFile(vocab_file, mode="w") as f:
    for subtoken in subtoken_list:
      f.write("'%s'\n" % _unicode_to_native(subtoken))


def _load_vocab_file(vocab_file, reserved_tokens=None):
  """Load vocabulary while ensuring reserved tokens are at the top."""
  if reserved_tokens is None:
    reserved_tokens = RESERVED_TOKENS

  subtoken_list = []
  with tf.io.gfile.GFile(vocab_file, mode="r") as f:
    for line in f:
      subtoken = native_to_unicode(line.strip())
      subtoken = subtoken[1:-1]  # Remove surrounding single-quotes
      if subtoken in reserved_tokens:
        continue
      subtoken_list.append(native_to_unicode(subtoken))
  return reserved_tokens + subtoken_list


def native_to_unicode(s):
  """Convert string to unicode (required in Python 2)."""
  try:               # Python 2
    return s if isinstance(s, unicode) else s.decode("utf-8")
  except NameError:  # Python 3
    return s


def _unicode_to_native(s):
  """Convert string from unicode to native format (required in Python 2)."""
  try:               # Python 2
    return s.encode("utf-8") if isinstance(s, unicode) else s
  except NameError:  # Python 3
    return s


def _split_string_to_tokens(text, master_char_set):
  """Splits text to a list of string tokens."""
  if not text:
    return []
  ret = []
  token_start = 0
  # Classify each character in the input string
  is_master = [c in master_char_set for c in text]
  for pos in xrange(1, len(text)):
    if is_master[pos] != is_master[pos - 1]:
      token = text[token_start:pos]
      if token != u" " or token_start == 0:
        ret.append(token)
      token_start = pos
  final_token = text[token_start:]
  ret.append(final_token)
  return ret
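

# Examples, assuming the default alphanumeric master char set:
#   _split_string_to_tokens(u"Hello world", _ALPHANUMERIC_CHAR_SET)
#       -> [u"Hello", u"world"]   (the lone separating space is dropped)
#   _split_string_to_tokens(u"Hello, world!", _ALPHANUMERIC_CHAR_SET)
#       -> [u"Hello", u", ", u"world", u"!"]   (non-master runs keep spaces)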


def _join_tokens_to_string(tokens, master_char_set):
  """Join a list of string tokens into a single string."""
  token_is_master = [t[0] in master_char_set for t in tokens]
  ret = []
  for i, token in enumerate(tokens):
    if i > 0 and token_is_master[i - 1] and token_is_master[i]:
      ret.append(u" ")
    ret.append(token)
  return "".join(ret)


def _escape_token(token, alphabet):
  r"""Replace characters that aren't in the alphabet and append "_" to token.

  Apply three transformations to the token:
    1. Replace underline character "_" with "\u", and backslash "\" with "\\".
    2. Replace characters outside of the alphabet with "\###;", where ### is the
       character's Unicode code point.
    3. Append "_" to mark the end of a token.

  Args:
    token: unicode string to be escaped
    alphabet: list of all known characters

  Returns:
    escaped string
  """
  token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
  ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
  return u"".join(ret) + "_"


def _unescape_token(token):
  r"""Replaces escaped characters in the token with their unescaped versions.

  Applies inverse transformations as _escape_token():
    1. Replace "\u" with "_", and "\\" with "\".
    2. Replace "\###;" with the unicode character the ### refers to.

  Args:
    token: escaped string

  Returns:
    unescaped string
  """

  def match(m):
    r"""Returns replacement string for matched object.

    Matched objects contain one of the strings that matches the regex pattern:
      r"\\u|\\\\|\\([0-9]+);"
    The strings can be '\u', '\\', or '\###;' (### is any digit number).

    m.group(0) refers to the entire matched string ('\u', '\\', or '\###;').
    m.group(1) refers to the first parenthesized subgroup ('###').

    m.group(0) exists for all match objects, while m.group(1) exists only for
    the string '\###;'.

    This function looks to see if m.group(1) exists. If it doesn't, then the
    matched string must be '\u' or '\\'. In this case, the corresponding
    replacement ('_' or '\') is returned. Note that in python, a single
    backslash is written as '\\', and double backslash as '\\\\'.

    If m.group(1) exists, then use the integer in m.group(1) to return a
    unicode character.

    Args:
      m: match object

    Returns:
      String to replace matched object with.
    """
    # Check if the matched strings are '\u' or '\\'.
    if m.group(1) is None:
      return u"_" if m.group(0) == u"\\u" else u"\\"

    # If m.group(1) exists, try and return unicode character.
    try:
      return six.unichr(int(m.group(1)))
    except (ValueError, OverflowError) as _:
      return _UNDEFINED_UNICODE

  # Use match function to replace escaped substrings in the token.
  return _UNESCAPE_REGEX.sub(match, token)
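

# Continuing the hypothetical example from _escape_token(): after the trailing
# "_" separator has been stripped by the caller,
#   _unescape_token(u"a\ub\\c\126;") -> u"a_b\c~"
# since "\u" maps back to "_", "\\" to "\", and "\126;" to chr(126) == "~".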


def _count_tokens(files,
                  file_byte_limit=1e6,
                  correct_strip=True,
                  master_char_set=None):
  """Return token counts of words in the files.

  Samples file_byte_limit bytes from each file, and counts the words that appear
  in the samples. The samples are semi-evenly distributed across the file.

  Args:
    files: List of filepaths
    file_byte_limit: Max number of bytes that will be read from each file.
    correct_strip: Whether to convert text to unicode before strip. This affects
      vocabulary generation for PY2: set correct_strip to False in PY2 to
      reproduce the previous common public result, or set it to True so that
      PY2 and PY3 produce a consistent vocabulary.
    master_char_set: the char set used to split text into tokens.

  Returns:
    Dictionary mapping tokens to the number of times they appear in the sampled
    lines from the files.
  """
  if master_char_set is None:
    master_char_set = _ALPHANUMERIC_CHAR_SET

  token_counts = collections.defaultdict(int)

  for filepath in files:
    with tf.io.gfile.GFile(filepath, mode="r") as reader:
      file_byte_budget = file_byte_limit
      counter = 0
      lines_to_skip = int(reader.size() / (file_byte_budget * 2))
      for line in reader:
        if counter < lines_to_skip:
          counter += 1
        else:
          if file_byte_budget < 0:
            break
          if correct_strip:
            line = native_to_unicode(line)
          line = line.strip()
          file_byte_budget -= len(line)
          counter = 0

          # Add words to token counts
          for token in _split_string_to_tokens(
              native_to_unicode(line), master_char_set):
            token_counts[token] += 1
  return token_counts


def _list_to_index_dict(lst):
  """Create dictionary mapping list items to their indices in the list."""
  return {item: n for n, item in enumerate(lst)}


def _split_token_to_subtokens(token, subtoken_dict, max_subtoken_length):
  """Splits a token into subtokens defined in the subtoken dict."""
  ret = []
  start = 0
  token_len = len(token)
  while start < token_len:
    # Find the longest subtoken, so iterate backwards.
    for end in xrange(min(token_len, start + max_subtoken_length), start, -1):
      subtoken = token[start:end]
      if subtoken in subtoken_dict:
        ret.append(subtoken)
        start = end
        break
    else:  # Did not break
      # If there is no possible encoding of the escaped token then one of the
      # characters in the token is not in the alphabet. This should be
      # impossible and would be indicative of a bug.
      raise ValueError("Was unable to split token \"%s\" into subtokens." %
                       token)
  return ret
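

# Greedy longest-match illustration with a hypothetical subtoken dict:
#   subtoken_dict = {u"he": 0, u"llo_": 1, u"h": 2, u"e": 3, u"l": 4, u"o": 5,
#                    u"_": 6}
#   _split_token_to_subtokens(u"hello_", subtoken_dict, max_subtoken_length=4)
#       -> [u"he", u"llo_"]
# At each position the longest subtoken in the dict wins, so "he" is chosen
# over "h", and "llo_" over shorter matches.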


def _generate_subtokens_with_target_vocab_size(token_counts,
                                               alphabet,
                                               target_size,
                                               threshold,
                                               min_count=None,
                                               reserved_tokens=None):
  """Generate subtoken vocabulary close to the target size."""
  if reserved_tokens is None:
    reserved_tokens = RESERVED_TOKENS

  if min_count is not None:
    logging.info("Using min_count=%d to generate vocab with target size %d",
                 min_count, target_size)
    return _generate_subtokens(
        token_counts, alphabet, min_count, reserved_tokens=reserved_tokens)

  def bisect(min_val, max_val):
    """Recursive function to binary search for subtoken vocabulary."""
    cur_count = (min_val + max_val) // 2
    logging.info("Binary search: trying min_count=%d (%d %d)", cur_count,
                 min_val, max_val)
    subtoken_list = _generate_subtokens(
        token_counts, alphabet, cur_count, reserved_tokens=reserved_tokens)

    val = len(subtoken_list)
    logging.info("Binary search: min_count=%d resulted in %d tokens", cur_count,
                 val)

    within_threshold = abs(val - target_size) < threshold
    if within_threshold or min_val >= max_val or cur_count < 2:
      return subtoken_list
    if val > target_size:
      other_subtoken_list = bisect(cur_count + 1, max_val)
    else:
      other_subtoken_list = bisect(min_val, cur_count - 1)

    # Return vocabulary dictionary with the closest number of tokens.
    other_val = len(other_subtoken_list)
    if abs(other_val - target_size) < abs(val - target_size):
      return other_subtoken_list
    return subtoken_list

  logging.info("Finding best min_count to get target size of %d", target_size)
  return bisect(_MIN_MIN_COUNT, _MAX_MIN_COUNT)


def _generate_alphabet_dict(iterable, reserved_tokens=None):
  """Create set of characters that appear in any element in the iterable."""
  if reserved_tokens is None:
    reserved_tokens = RESERVED_TOKENS
  alphabet = {c for token in iterable for c in token}
  alphabet |= {c for token in reserved_tokens for c in token}
  alphabet |= _ESCAPE_CHARS  # Add escape characters to alphabet set.
  return alphabet


def _count_and_gen_subtokens(token_counts, alphabet, subtoken_dict,
                             max_subtoken_length):
  """Count number of times subtokens appear, and generate new subtokens.

  Args:
    token_counts: dict mapping tokens to the number of times they appear in the
      original files.
    alphabet: list of allowed characters. Used to escape the tokens, which
      guarantees that all tokens can be split into subtokens.
    subtoken_dict: dict mapping subtokens to ids.
    max_subtoken_length: maximum length of subtoken in subtoken_dict.

  Returns:
    A defaultdict mapping subtokens to the number of times they appear in the
    tokens. The dict may contain new subtokens.
  """
  subtoken_counts = collections.defaultdict(int)
  for token, count in six.iteritems(token_counts):
    token = _escape_token(token, alphabet)
    subtokens = _split_token_to_subtokens(token, subtoken_dict,
                                          max_subtoken_length)

    # Generate new subtokens by taking substrings from token.
    start = 0
    for subtoken in subtokens:
      for end in xrange(start + 1, len(token) + 1):
        new_subtoken = token[start:end]
        subtoken_counts[new_subtoken] += count
      start += len(subtoken)

  return subtoken_counts
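

# Small worked example with hypothetical counts: for token_counts = {u"abc": 5}
# and a single-character subtoken_dict, the escaped token u"abc_" splits into
# [u"a", u"b", u"c", u"_"], and every substring that starts on a subtoken
# boundary (u"a", u"ab", u"abc", u"abc_", u"b", u"bc", u"bc_", u"c", u"c_",
# u"_") has its count incremented by 5.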


def _filter_and_bucket_subtokens(subtoken_counts, min_count):
  """Return a bucketed list of subtokens that are filtered by count.

  Args:
    subtoken_counts: defaultdict mapping subtokens to their counts
    min_count: int count used to filter subtokens

  Returns:
    List of subtoken sets, where subtokens in set i have the same length=i.
  """
  # Create list of buckets, where subtokens in bucket i have length i.
  subtoken_buckets = []
  for subtoken, count in six.iteritems(subtoken_counts):
    if count < min_count:  # Filter out subtokens that don't appear enough
      continue
    while len(subtoken_buckets) <= len(subtoken):
      subtoken_buckets.append(set())
    subtoken_buckets[len(subtoken)].add(subtoken)
  return subtoken_buckets


def _gen_new_subtoken_list(subtoken_counts,
                           min_count,
                           alphabet,
                           reserved_tokens=None):
  """Generate candidate subtokens ordered by count, and new max subtoken length.

  Add subtokens to the candidate list in order of length (longest subtokens
  first). When a subtoken is added, the counts of each of its prefixes are
  decreased. Prefixes that don't appear much outside the subtoken are not added
  to the candidate list.

  For example:
    subtoken being added to candidate list: 'translate'
    subtoken_counts: {'translate':10, 't':40, 'tr':16, 'tra':12, ...}
    min_count: 5

  When 'translate' is added, subtoken_counts is updated to:
    {'translate':0, 't':30, 'tr':6, 'tra': 2, ...}

  The subtoken 'tra' will not be added to the candidate list, because it appears
  twice (less than min_count) outside of 'translate'.

  Args:
    subtoken_counts: defaultdict mapping str subtokens to int counts
    min_count: int minimum count requirement for subtokens
    alphabet: set of characters. Each character is added to the subtoken list to
      guarantee that all tokens can be encoded.
    reserved_tokens: list of tokens that will be added to the beginning of the
      returned subtoken list.

  Returns:
    List of candidate subtokens in decreasing count order, and maximum subtoken
    length
  """
  if reserved_tokens is None:
    reserved_tokens = RESERVED_TOKENS

  # Create a list of (count, subtoken) for each candidate subtoken.
  subtoken_candidates = []

  # Use bucketed list to iterate through subtokens in order of length.
  # subtoken_buckets[i] = set(subtokens), where each subtoken has length i.
  subtoken_buckets = _filter_and_bucket_subtokens(subtoken_counts, min_count)
  max_subtoken_length = len(subtoken_buckets) - 1

  # Go through the list in reverse order to consider longer subtokens first.
  for subtoken_len in xrange(max_subtoken_length, 0, -1):
    for subtoken in subtoken_buckets[subtoken_len]:
      count = subtoken_counts[subtoken]

      # Possible if this subtoken is a prefix of another token.
      if count < min_count:
        continue

      # Ignore alphabet/reserved tokens, which will be added manually later.
      if subtoken not in alphabet and subtoken not in reserved_tokens:
        subtoken_candidates.append((count, subtoken))

      # Decrement count of the subtoken's prefixes (if a longer subtoken is
      # added, its prefixes lose priority to be added).
      for end in xrange(1, subtoken_len):
        subtoken_counts[subtoken[:end]] -= count

  # Add alphabet subtokens (guarantees that all strings are encodable).
  subtoken_candidates.extend((subtoken_counts.get(a, 0), a) for a in alphabet)

  # Order subtoken candidates by decreasing count.
  subtoken_list = [t for _, t in sorted(subtoken_candidates, reverse=True)]

  # Add reserved tokens to beginning of the list.
  subtoken_list = reserved_tokens + subtoken_list
  return subtoken_list, max_subtoken_length


def _generate_subtokens(token_counts,
                        alphabet,
                        min_count,
                        num_iterations=4,
                        reserved_tokens=None):
  """Create a list of subtokens in decreasing order of frequency.

  Args:
    token_counts: dict mapping str tokens -> int count
    alphabet: set of characters
    min_count: int minimum number of times a subtoken must appear before it is
      added to the vocabulary.
    num_iterations: int number of iterations to generate new tokens.
    reserved_tokens: list of tokens that will be added to the beginning of the
      returned subtoken list.

  Returns:
    Sorted list of subtokens (most frequent first)
  """
  if reserved_tokens is None:
    reserved_tokens = RESERVED_TOKENS

  # Use alphabet set to create initial list of subtokens
  subtoken_list = reserved_tokens + list(alphabet)
  max_subtoken_length = 1

  # On each iteration, segment all words using the subtokens defined in
  # subtoken_dict, count how often the resulting subtokens appear, and update
  # the dictionary with subtokens w/ high enough counts.
  for i in xrange(num_iterations):
    logging.info("\tGenerating subtokens: iteration %d", i)
    # Generate new subtoken->id dictionary using the new subtoken list.
    subtoken_dict = _list_to_index_dict(subtoken_list)

    # Create dict mapping subtoken->count, with additional subtokens created
    # from substrings taken from the tokens.
    subtoken_counts = _count_and_gen_subtokens(token_counts, alphabet,
                                               subtoken_dict,
                                               max_subtoken_length)

    # Generate new list of subtokens sorted by subtoken count.
    subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
        subtoken_counts, min_count, alphabet, reserved_tokens)

    logging.info("\tVocab size: %d", len(subtoken_list))
  return subtoken_list