"""BERT-related preprocessing ops (using WordPiece tokenizer).""" |
|
|
|
from big_vision.pp import utils |
|
from big_vision.pp.registry import Registry |
|
import tensorflow as tf |
|
import tensorflow_text |
|


def _create_bert_tokenizer(vocab_path):
  """Returns cls_token id and tokenizer to use in a tf.Dataset.map function."""
  # Build the tokenizer inside tf.init_scope() so its vocab lookup table is
  # created eagerly, once, rather than inside every tf.data graph trace.
  with tf.init_scope():
    tokenizer = tensorflow_text.BertTokenizer(
        vocab_path,
        token_out_type=tf.int32,
        lower_case=True,
    )

  # Look up the id of the "[CLS]" token directly from the vocab file; it is
  # prepended to every tokenized sequence below.
  with tf.io.gfile.GFile(vocab_path) as f:
    vocab = f.read().split("\n")
  cls_token = vocab.index("[CLS]")

  return cls_token, tokenizer
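

# Illustrative sketch of what the helper above returns (hypothetical vocab
# path; shapes follow from tensorflow_text.BertTokenizer with
# token_out_type=tf.int32):
#
#   cls_token, tokenizer = _create_bert_tokenizer("/path/to/vocab.txt")
#   ids = tokenizer.tokenize(tf.constant(["a photo of a cat"]))
#
# `cls_token` is the row index of "[CLS]" in the vocab file, and `ids` is an
# int32 RaggedTensor of shape [batch, words, wordpieces].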


@Registry.register("preprocess_ops.bert_tokenize")
@utils.InKeyOutKey(indefault=None, outdefault="labels")
def get_pp_bert_tokenize(vocab_path, max_len, sample_if_multi=True):
  """Extracts tokens with tensorflow_text.BertTokenizer.

  Args:
    vocab_path: Path to a file containing the vocabulary for the WordPiece
      tokenizer. It's the "vocab.txt" file in the zip file downloaded from
      the original repo https://github.com/google-research/bert
    max_len: Number of tokens after tokenization.
    sample_if_multi: Whether to tokenize a randomly sampled text when multiple
      texts are provided. If set to `False`, the first text is always used.

  Returns:
    A preprocessing Op.
  """
  cls_token, tokenizer = _create_bert_tokenizer(vocab_path)

  def _pp_bert_tokenize(labels):
    # Append an empty string so the op is well-defined even for an empty list
    # of texts (the "" is only picked when no real text is present).
    labels = tf.reshape(labels, (-1,))
    labels = tf.concat([labels, [""]], axis=0)
    if sample_if_multi:
      # Sample uniformly among the real texts (excluding the appended "").
      num_texts = tf.maximum(tf.shape(labels)[0] - 1, 1)
      txt = labels[tf.random.uniform([], 0, num_texts, dtype=tf.int32)]
    else:
      txt = labels[0]

    # Tokenize as a batch of one, pad/truncate to max_len - 1 tokens, then
    # prepend the [CLS] token id to get a sequence of exactly max_len tokens.
    token_ids = tokenizer.tokenize(txt[None])
    padded_token_ids, mask = tensorflow_text.pad_model_inputs(
        token_ids, max_len - 1)
    del mask  # Unused; padded positions are zeros.
    count = tf.shape(padded_token_ids)[0]
    padded_token_ids = tf.concat(
        [tf.fill([count, 1], cls_token), padded_token_ids], axis=1)
    return padded_token_ids[0]

  return _pp_bert_tokenize
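

if __name__ == "__main__":
  # Minimal smoke-test sketch: write a tiny, hypothetical WordPiece vocab to a
  # temporary file and exercise the tokenizer helper above. In big_vision
  # configs the registered op is typically referenced from a pp string, e.g.
  # "bert_tokenize(inkey='texts', vocab_path='...', max_len=16)".
  import tempfile

  _VOCAB = ["[PAD]", "[UNK]", "[CLS]", "[SEP]",
            "a", "photo", "of", "cat", "##s"]
  with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as _tmp:
    _tmp.write("\n".join(_VOCAB))

  _cls, _tok = _create_bert_tokenizer(_tmp.name)
  print("cls_token id:", _cls)  # Expected: 2, the index of "[CLS]" above.
  print(_tok.tokenize(tf.constant(["a photo of cats"])))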