# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text-centric preprocessing ops.
All preprocessing ops should return a data processing functors. A data
is represented as a dictionary of (TF) tensors. The functors output a modified
dictionary.
A commonly used key for the tokenized output is "labels".
"""
import functools
import importlib
import os

from absl import logging
from big_vision.datasets.imagenet import class_names as imagenet_class_names
from big_vision.pp import ops_general
from big_vision.pp import tokenizer as bv_tok
from big_vision.pp import utils
from big_vision.pp.registry import Registry
import tensorflow as tf
from tensorflow.io import gfile

import sentencepiece
SPProcessor = sentencepiece.SentencePieceProcessor

# Temporarily force the pure-Python protobuf implementation while importing
# the SentencePiece model proto, so it doesn't clash with whatever binary
# protobuf runtime is already loaded; the override is removed right after.
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
import sentencepiece.sentencepiece_model_pb2
del os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION']
SPModelProto = sentencepiece.sentencepiece_model_pb2.ModelProto


# TODO: b/lbeyer - softly introduce and move to new tokenizer API.
KNOWN_TOKENIZERS = {
    "mc4":  # used in multilingual models (mT5, PaLI), vocab_size=250_000
        "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
    "cc_all":  # vocab_size=32_000
        "gs://t5-data/vocabs/cc_all.32000/sentencepiece.model",
    "c4_en":  # vocab_size=32_000
        "gs://t5-data/vocabs/cc_en.32000/sentencepiece.model",
    "t5":  # same as cc_all, but with 100 extra dummy tokens used by T5 models
        "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model",
    "mt5":  # same as mc4, but with 100 extra dummy tokens used by mT5 models
        "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
}


def create_tokenizer(model="c4_en", add_eos=True, add_bos=False):
  """Creates a tokenizer which can be used in tfds."""
  logging.info("Creating tokenizer: %s", model)
  with gfile.GFile(KNOWN_TOKENIZERS.get(model, model), "rb") as f:
    model = f.read()
  # Lazy import of tensorflow_text so it is an optional dependency for
  # the users of this file.
  import tensorflow_text
  return tensorflow_text.SentencepieceTokenizer(
      model=model, add_eos=add_eos, add_bos=add_bos)
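
# Example usage (a sketch; requires `tensorflow_text` and, for the named
# vocabs above, access to the public GCS buckets):
#   tok = create_tokenizer("c4_en", add_eos=True)
#   ids = tok.tokenize("a photo of a cat")  # -> 1-D int32 tf.Tensor of ids.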


def tokenize(input_text, tokenizer, max_len, *, pad_value, force_eos,
             multi_text=False):
  """Tokenizes string, truncating/padding with `pad_value` to `max_len`."""

  def pad(tokens):
    # Truncate/pad to max_len.
    if force_eos:
      tokens = tf.cond(
          tf.shape(tokens)[0] >= max_len,
          lambda: tf.concat(
              # If too long, cut off, but do keep the final EOS token.
              [tokens[:max_len - 1], tokens[-1:]], axis=0),
          lambda: tf.pad(
              tokens, [(0, max_len - tf.shape(tokens)[0])],
              constant_values=pad_value),
      )
    else:
      tokens = tokens[:max_len]
      tokens = tf.pad(
          tokens, [(0, max_len - tf.shape(tokens)[0])],
          constant_values=pad_value)
    tokens.set_shape([max_len])
    return tokens

  tokens = tokenizer.tokenize(input_text)
  if multi_text:
    tokens = tokens.to_tensor(pad_value)  # tf.RaggedTensor to tf.Tensor.
    tokens = tf.reshape(tokens, [-1, tf.shape(tokens)[-1]])
    tokens = tf.map_fn(pad, tokens)  # `map_fn` only maps on axis 0.
    final_shape = tf.concat([tf.shape(input_text), [max_len]], axis=0)
    return tf.reshape(tokens, final_shape)
  else:
    return pad(tokens)
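
# Example (a sketch, assuming `tok` from `create_tokenizer` above):
#   tokenize(tf.constant(["one text", "two"]), tok, max_len=8,
#            pad_value=0, force_eos=False, multi_text=True)
#   # -> int32 tf.Tensor of shape [2, 8], each row truncated/padded to 8.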


@Registry.register("preprocess_ops.tokenize")
@utils.InKeyOutKey(indefault=None, outdefault="labels")
def get_pp_tokenize(
    max_len,
    eos,
    model="c4_en",
    lower=True,
    sample_if_multi=True,
    pad_value="<pad>",
    add_bos=False,
):
  """Tokenizes a text.

  Let's assume max_len=3 and id("</s>")=1, id("a")=2, then we have

  1. `eos="none", pad_value=0`:
      - "a"   -> [2, 0, 0]
      - "aa"  -> [2, 2, 0]
      - "aaa" -> [2, 2, 2]

  2. `eos="yes", pad_value=0`:
      - "a"   -> [2, 1, 0]
      - "aa"  -> [2, 2, 1]
      - "aaa" -> [2, 2, 2]

     This is usually used with generative models that need to learn when to
     properly predict a "</s>" (when the sentence is finished) and when to
     abstain (when the sentence is truncated).

  3. `eos="sticky", pad_value=0`:
      - "a"   -> [2, 1, 0]
      - "aa"  -> [2, 2, 1]
      - "aaa" -> [2, 2, 1]

  4. `eos="sticky", pad_value=1`:
      - "a"   -> [2, 1, 1]
      - "aa"  -> [2, 2, 1]
      - "aaa" -> [2, 2, 1]

     This is traditionally used with contrastive models that use the last
     token for embeddings, similarly to "cls" tokens in BERT-style models.

  Args:
    max_len: maximum length of the tokenized text.
    eos: Whether to add an "</s>" (end of sentence) token and whether to keep
      it when the sequence is longer than `max_len - 1`. See examples above
      for details. Valid values: "none", "yes", "sticky".
    model: a path to a pretrained sentencepiece model, or the name of one of
      the `KNOWN_TOKENIZERS` above.
    lower: lowercase the text before tokenizing.
    sample_if_multi: If the input holds multiple texts, randomly pick one when
      this is True; otherwise tokenize all texts and keep the input's batch
      shape in the result.
    pad_value: which token to pad the sequence with. If a string (for example
      `"<pad>"`), its id is looked up in the vocabulary and used. Note that
      there is no guarantee to have any padding at the end of the sentence if
      the sentence is longer than `max_len`.
    add_bos: whether to prepend a beginning-of-sentence token.

  Returns:
    An op that outputs tokenized text.
  """
  if eos not in ("yes", "none", "sticky"):
    raise ValueError(f"Invalid value for eos: '{eos}'.")

  tokenizer = create_tokenizer(model, add_eos=eos != "none", add_bos=add_bos)

  if isinstance(pad_value, str):
    pad_value = tokenizer.string_to_id(pad_value)

  def _pp_tokenize(txt):
    if sample_if_multi and tf.convert_to_tensor(txt).ndim:
      # TODO: I wish this code-path could die.
      logging.warning("sample_if_multi is deprecated and will be removed. "
                      "Call `choice` (and maybe `setdefault`) instead.")
      txt = ops_general.get_choice(key="t")(
          ops_general.get_setdefault("t", "")({"t": txt}))["t"]

    if lower:
      txt = tf.strings.lower(txt) if sample_if_multi else tf.map_fn(
          tf.strings.lower, txt)

    return tokenize(
        txt,
        tokenizer,
        max_len,
        pad_value=pad_value,
        force_eos=eos == "sticky",
        multi_text=not sample_if_multi)

  return _pp_tokenize
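
# Example pp-string usage (a sketch; `inkey`/`outkey` come from the
# `InKeyOutKey` decorator, and the surrounding spec is illustrative):
#   'tokenize(max_len=16, eos="sticky", inkey="text", outkey="labels")'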


@Registry.register("preprocess_ops.coco_captions")
def get_coco_captions(outkey="captions"):
  """Extracts coco's captions from nested dict."""

  def _pp_coco_captions(data):
    data[outkey] = data["captions"]["text"]
    return data

  return _pp_coco_captions


@Registry.register("preprocess_ops.clip_i1k_label_names")
@utils.InKeyOutKey(indefault="label", outdefault="labels")
def get_pp_clip_i1k_label_names():
  """Converts i1k label numbers to strings, using CLIP's class names."""

  def _pp_imagenet_labels(label):
    return tf.gather(imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, label)

  return _pp_imagenet_labels


@Registry.register("preprocess_ops.lower")
@utils.InKeyOutKey(indefault="text", outdefault="text")
def get_lower():
  """Lowercases text feature."""

  def _pp_lower(text):
    return tf.strings.lower(text)

  return _pp_lower


def _add_pieces(model_bytes, extra_pieces):
  """Adds extra pieces to the sentencepiece model given by `model_bytes`."""
  model = SPProcessor()
  model.LoadFromSerializedProto(model_bytes)
  unk_idx = model.PieceToId("<unk>")
  assert model.IdToPiece(unk_idx) == "<unk>", model.IdToPiece(unk_idx)

  model_proto = SPModelProto.FromString(model_bytes)
  idx_to_updated_piece = {}
  for piece in extra_pieces:
    # The SentencePieceModel proto stores whitespaces as the special
    # character '▁'. We perform the conversion here.
    piece = piece.replace(" ", "▁")
    spiece = model_proto.SentencePiece(
        piece=piece,
        # We set the highest score to force priority on user-defined tokens.
        score=0.0,
        type=model_proto.SentencePiece().Type.USER_DEFINED,
    )
    existing_idx = model.PieceToId(piece)
    # `PieceToId` maps any unknown piece to `unk_idx`, so a piece already
    # exists iff its id differs from `unk_idx` -- except for "<unk>" itself,
    # which legitimately maps to `unk_idx`; the XOR accounts for that case.
    if (existing_idx != unk_idx) ^ (piece == "<unk>"):
      idx_to_updated_piece[existing_idx] = spiece
      logging.info("Updating token at idx %d: %s", existing_idx, spiece.piece)
    else:
      model_proto.pieces.append(spiece)

  # Replace duplicated pieces with updated ones.
  updated_pieces = [
      idx_to_updated_piece.get(i, piece)
      for i, piece in enumerate(model_proto.pieces)
  ]
  del model_proto.pieces[:]
  model_proto.pieces.extend(updated_pieces)
  return model_proto.SerializeToString()
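
# Example (a sketch; the "<loc....>" piece names are hypothetical):
#   with gfile.GFile(KNOWN_TOKENIZERS["c4_en"], "rb") as f:
#     extended = _add_pieces(f.read(), ["<loc0000>", "<loc0001>"])
#   # `extended` is a serialized model whose vocab now contains both pieces.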


def _iterable(x):
  if isinstance(x, tf.RaggedTensor):
    return True
  if getattr(x, "ndim", 0) > 1:  # np, jnp
    return True
  if isinstance(x, (list, tuple)) and not isinstance(x[0], (int, float)):
    return True
  return False


@Registry.register("tokenizers.sp")
class SentencepieceTokenizer(bv_tok.Tokenizer):
  """Wraps a `sentencepiece.SentencePieceProcessor`.

  If you plan to use this tokenizer, please familiarize yourself with the
  test cases first. This is likely to save you a lot of trouble down the
  road, trust me!
  """

  def __init__(self, model, tokensets=()):
    with gfile.GFile(KNOWN_TOKENIZERS.get(model, model), "rb") as f:
      model_bytes = f.read()
    extras = bv_tok.get_extra_tokens(tokensets)
    model_bytes = _add_pieces(model_bytes, extras)
    self._tok_sp = SPProcessor()
    self._tok_sp.LoadFromSerializedProto(model_bytes)
    self.extras = {self._tok_sp.PieceToId(x): x for x in extras}

  def to_int(self, text, *, bos=False, eos=False):
    def _single(s):
      return (
          ([self.bos_token] if bos else []) +
          self._tok_sp.EncodeAsIds(s) +
          ([self.eos_token] if eos else [])
      )
    if isinstance(text, str):
      return _single(text)
    return type(text)([_single(s) for s in text])

  def to_str(self, tokens, *, stop_at_eos=True):
    def _single(toks):
      toks = [int(t) for t in toks]  # We really need this for DecodeIds.
      if stop_at_eos:
        try:  # SentencePiece strips the eos, but does not stop at it, so we do.
          toks = toks[:toks.index(self.eos_token)]
        except ValueError:  # No eos token found, nothing to do.
          pass
      return self._tok_sp.DecodeIds(toks)
    if _iterable(tokens):
      return [_single(toks) for toks in tokens]
    return _single(tokens)

  def _check_known(self, piece):
    if (id_ := self._tok_sp.PieceToId(piece)) == self._tok_sp.unk_id():
      logging.error("Piece '%s' is not known (unk=%s)!", piece, id_)
    return id_

  def to_piece(self, idx):
    return self._tok_sp.IdToPiece(int(idx))

  @property
  def pad_token(self):
    return self._tok_sp.pad_id()

  @property
  def eos_token(self):
    return self._tok_sp.eos_id()

  @property
  def bos_token(self):
    return self._tok_sp.bos_id()

  @property
  def vocab_size(self):
    return self._tok_sp.GetPieceSize()

  # For the _tf_op variants, we need a lot of wrapping boilerplate.
  def to_int_tf_op(self, text, *, bos=False, eos=False):
    text = tf.convert_to_tensor(text)
    if text.ndim == 0:
      def fn(txt):
        string = txt.numpy().decode()
        return tf.constant(self.to_int(string, bos=bos, eos=eos), tf.int32)
      return tf.py_function(fn, [text], tf.int32)
    else:
      def fn(txt):
        strings = [s.decode() for s in txt.numpy().tolist()]
        toks = self.to_int(strings, bos=bos, eos=eos)
        return tf.ragged.constant(toks)
      out_type = tf.RaggedTensorSpec([tf.shape(text)[0], None], tf.int32)
      return tf.py_function(fn, [text], Tout=out_type)

  def to_str_tf_op(self, tokens, *, stop_at_eos=True):
    def single(t):
      fn = functools.partial(self.to_str, stop_at_eos=stop_at_eos)
      return tf.numpy_function(fn, [t], tf.string, stateful=False)
    if _iterable(tokens):
      return tf.map_fn(single, tokens, tf.string)
    return single(tokens)
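
# Example round trip (a sketch; assumes the "c4_en" vocab is reachable and
# that it defines an eos id; T5-style vocabs typically have no bos):
#   tok = SentencepieceTokenizer("c4_en")
#   ids = tok.to_int("a photo of a cat", eos=True)   # Plain list of ints.
#   txt = tok.to_str(ids)                            # Stops at the first eos.
#   batch = tok.to_int_tf_op(tf.constant(["a", "b c"]))  # tf.RaggedTensor.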
|