# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT-related preprocessing ops (using WordPiece tokenizer)."""
from big_vision.pp import utils
from big_vision.pp.registry import Registry
import tensorflow as tf
import tensorflow_text
# Internally using
# BasicTokenizer
# https://github.com/tensorflow/text/blob/df5250d6cf1069990df4bf55154867391ab5381a/tensorflow_text/python/ops/bert_tokenizer.py#L67
# WordpieceTokenizer
# https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/wordpiece_tokenizer.py


def _create_bert_tokenizer(vocab_path):
  """Returns cls_token id and tokenizer to use in a tf.Dataset.map function."""
  # Create tokenizer inside a tf.init_scope so the vocab is only loaded from
  # disk once per dataset iterator (see: http://(internal link)).
  # TODO: Make a local copy of vocab if creating many iterators.
  with tf.init_scope():
    tokenizer = tensorflow_text.BertTokenizer(
        vocab_path,
        token_out_type=tf.int32,
        lower_case=True,
    )
  with tf.io.gfile.GFile(vocab_path) as f:
    vocab = f.read().split("\n")
  cls_token = vocab.index("[CLS]")
  return cls_token, tokenizer
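

# Illustrative usage sketch of the helper above (not part of the original
# big_vision module); the vocab path and example string are hypothetical
# placeholders:
#
#   cls_token, tokenizer = _create_bert_tokenizer("/path/to/vocab.txt")
#   ids = tokenizer.tokenize(tf.constant(["a photo of a cat"]))
#
# `ids` is an int32 tf.RaggedTensor of WordPiece ids (one ragged row per input
# string), which the op below pads to a fixed length and prefixes with [CLS].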


@Registry.register("preprocess_ops.bert_tokenize")
@utils.InKeyOutKey(indefault=None, outdefault="labels")
def get_pp_bert_tokenize(vocab_path, max_len, sample_if_multi=True):
  """Extracts tokens with tensorflow_text.BertTokenizer.

  Args:
    vocab_path: Path to a file containing the vocabulary for the WordPiece
      tokenizer. It's the "vocab.txt" file in the zip file downloaded from
      the original repo https://github.com/google-research/bert
    max_len: Number of tokens after tokenization.
    sample_if_multi: If `True`, a text is sampled uniformly at random from the
      provided texts; if `False`, the first text is always tokenized.

  Returns:
    A preprocessing Op.
  """
  cls_token, tokenizer = _create_bert_tokenizer(vocab_path)

  def _pp_bert_tokenize(labels):
    labels = tf.reshape(labels, (-1,))
    labels = tf.concat([labels, [""]], axis=0)
    if sample_if_multi:
      num_texts = tf.maximum(tf.shape(labels)[0] - 1, 1)  # Don't sample "".
      txt = labels[tf.random.uniform([], 0, num_texts, dtype=tf.int32)]
    else:
      txt = labels[0]  # Always works, since we append "" earlier on.

    token_ids = tokenizer.tokenize(txt[None])
    padded_token_ids, mask = tensorflow_text.pad_model_inputs(
        token_ids, max_len - 1)
    del mask  # Recovered from zero padding in model.
    count = tf.shape(padded_token_ids)[0]
    padded_token_ids = tf.concat(
        [tf.fill([count, 1], cls_token), padded_token_ids], axis=1)
    return padded_token_ids[0]

  return _pp_bert_tokenize
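

# Hedged pipeline sketch (illustrative; the field names and vocab path below
# are hypothetical). Assuming big_vision's pp builder, the registered op is
# typically referenced from a preprocessing string and applied to dict-valued
# tf.data examples; the `inkey`/`outkey` arguments come from the
# `utils.InKeyOutKey` wrapper above:
#
#   from big_vision.pp import builder as pp_builder
#   pp_fn = pp_builder.get_preprocess_fn(
#       'bert_tokenize(inkey="texts", vocab_path="/path/to/vocab.txt", '
#       'max_len=64)')
#   ds = ds.map(pp_fn)  # Adds an int32 "labels" field of shape [64].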