# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""pp ops."""
import functools
import string
from big_vision.pp import ops_text
from big_vision.pp import utils
from big_vision.pp.registry import Registry
import big_vision.pp.tokenizer as bv_tok
import numpy as np
import tensorflow as tf


@Registry.register('tokenizers.gemma')
def get_tokenizer_gemma(
tokensets=(),
model='gs://big_vision/gemma_tokenizer.model',
):
# See (internal link) for colab playground.
return ops_text.SentencepieceTokenizer(model=model, tokensets=tokensets)


@functools.cache
def tokenize_constant(model, text, bos='no', eos='no', length=None):
"""Tokenize a constant string, with memoization."""
assert eos in ('no', 'yes', 'sticky')
assert bos in ('no', 'yes')
tokenizer = bv_tok.get_tokenizer(model)
tokens = tokenizer.to_int(
text, bos=bos == 'yes', eos=eos in ('yes', 'sticky'))
if length is None:
return tokens
if len(tokens) > length:
if eos == 'sticky':
return np.r_[tokens[:length-1], tokens[-1]]
else:
return tokens[:length]
else:
return np.pad(tokens, [(0, length - len(tokens))],
constant_values=tokenizer.pad_token)
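
# Illustration of the 'sticky' EOS behavior (hypothetical token ids: assume the
# tokenizer maps "hi there" to [5, 6, 7] and uses 1 as EOS):
#   tokenize_constant(model, 'hi there', eos='sticky', length=3)  # -> [5, 6, 1]
#   tokenize_constant(model, 'hi there', eos='yes', length=3)     # -> [5, 6, 7]
# i.e. 'sticky' keeps the trailing EOS even when the text gets truncated.

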
@Registry.register('preprocess_ops.tolen')
@utils.InKeyOutKey(indefault=None, outdefault=None, with_data=True)
def get_tolen(length, *, sticky_end=False, pad_value=None, pad_key=None):
"""Gets token to a fixed length."""
def _tolen(x, data):
if not length:
return x
xlen = tf.shape(x)[0]
if sticky_end:
trunc_fn = lambda: tf.concat([x[:length - 1], x[-1:]], axis=0)
else:
trunc_fn = lambda: x[:length]
# Potentially get the pad value from a data key (to be tokenizer agnostic).
pad_value_ = pad_value
if pad_key:
pad_value_ = data[pad_key]
# If coming from a previous tokenization op, it's probably 1D; take first.
if getattr(pad_value_, 'ndim', 0) == 1:
pad_value_ = pad_value_[0]
assert pad_value_ is not None, 'Need either pad_value or pad_key.'
pad_fn = lambda: tf.pad(x, [(0, length - xlen)], constant_values=pad_value_)
out = tf.cond(xlen >= length, trunc_fn, pad_fn)
out.set_shape([length])
return out
return _tolen
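
# Usage sketch (hypothetical values): the InKeyOutKey decorator lets the getter
# take key/inkey/outkey kwargs and returns an op mapping a data dict to a data
# dict, mirroring how it is used inside `_pp_tokenize` below:
#   op = get_tolen(5, pad_value=0, key='text')
#   op({'text': tf.constant([7, 8, 9])})['text']  # -> [7, 8, 9, 0, 0]

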
@Registry.register('preprocess_ops.tok')
def get_tokenize(model, length=None, *, bos='no', eos='no',
text=None, key=None, inkey=None, outkey=None):
"""Tokenizes and optionally truncates/pads a string."""
assert eos in ('no', 'yes', 'sticky')
assert bos in ('no', 'yes')
outkey_ = outkey or key
inkey_ = inkey or key
if text is not None:
assert inkey is None, 'Either inkey or text, not both.'
tokens = tokenize_constant(model, text, bos=bos, eos=eos, length=length)
def _pp_tokenize_text(data):
data[outkey_] = tokens
return data
return _pp_tokenize_text
tokenizer = bv_tok.get_tokenizer(model)
def _pp_tokenize(data):
assert getattr(data[inkey_], 'ndim', 0) == 0, (
f'Can only tokenize single string ({inkey_}, {data[inkey_].ndim}-D)')
toks = tokenizer.to_int_tf_op(
data[inkey_], bos=bos == 'yes', eos=eos in ('yes', 'sticky'))
tolen = get_tolen(
length, sticky_end=eos == 'sticky',
pad_value=bv_tok.get_tokenizer(model).pad_token,
key='tmp',
)
toks = tolen({'tmp': toks})['tmp']
data[outkey_] = toks
return data
return _pp_tokenize
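
# Usage sketch (hypothetical arguments): as a registered op, this would usually
# appear inside a big_vision preprocessing spec string, e.g.
#   tok(key='prefix', model='gemma', bos='yes', length=32)
# tokenizing data['prefix'] in place, prepending BOS and padding/truncating to
# 32 tokens; the `model` string is resolved via bv_tok.get_tokenizer.

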
@Registry.register('preprocess_ops.masked_concat')
def get_masked_concat(keys, outkey='text', **masks):
  """Concatenates `keys` into `outkey` and emits one mask tensor per kwarg."""
assert all(len(keys) == len(m) for m in masks.values()), (keys, masks)
def _masked_concat(data):
data[outkey] = tf.concat([data[k] for k in keys], axis=0)
for mask_name, mask_vals in masks.items():
m = [tf.fill(tf.shape(data[k]), v) for k, v in zip(keys, mask_vals)]
data[mask_name] = tf.concat(m, axis=0)
return data
return _masked_concat
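
# Usage sketch (hypothetical keys and mask values):
#   masked_concat(['prefix', 'suffix'], mask_ar=[0, 1], mask_loss=[0, 1])
# sets data['text'] to the concatenation of both token sequences and emits
# data['mask_ar'] / data['mask_loss'], filled with 0 over the prefix part and
# 1 over the suffix part (same length as data['text']).

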
@Registry.register('preprocess_ops.strfmt')
def get_strfmt(template, outkey='text'):
"""Formats a string template with content form the data dict."""
def _template(data):
outputs = []
parts = string.Formatter().parse(template)
for (literal_text, field_name, format_spec, conversion) in parts:
# For now, we keep it simple and don't support fancy format specs.
      # But we can add support for that via py_func as soon as we need it.
assert not format_spec and not conversion
outputs.append(tf.constant(literal_text))
if field_name:
value = data[field_name]
# Convert any non-strings (numbers, vectors) to a string.
if tf.convert_to_tensor(value).dtype != tf.string:
value = tf.strings.format('{}', value, summarize=-1)
outputs.append(value)
data[outkey] = tf.strings.join(outputs)
return data
return _template
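
# Usage sketch (hypothetical template and field): strfmt('answer en {question}',
# outkey='prefix') joins the literal text with data['question'] (non-string
# values are first converted with tf.strings.format), e.g. producing
# data['prefix'] == 'answer en how many cats?'.

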
@Registry.register('preprocess_ops.strjoin')
@utils.InKeyOutKey()
def get_strjoin(glue):
  """Joins a vector of strings into a single string, separated by `glue`."""
def _strjoin(x):
return tf.strings.reduce_join(x, separator=glue)
return _strjoin


@Registry.register('preprocess_ops.majority')
@utils.InKeyOutKey()
def get_majority():
  """Returns the most frequent element; ties go to the earliest-seen value."""
def _majority(x):
val, _, count = tf.unique_with_counts(x) # Sadly, stablesorted.
return val[tf.argmax(count)]
return _majority
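
# Usage sketch (hypothetical values): for x = [b'cat', b'dog', b'cat'] the op
# returns b'cat'. Ties go to the value seen first, since tf.unique_with_counts
# lists values in order of first occurrence and tf.argmax picks the first max.

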
@Registry.register('preprocess_ops.getidx')
def getidx(inkey, index_key, outkey=None):
"""Indexes a tensor and stores result in outkey."""
def _getidx(data):
idx = data[index_key]
array = data[inkey]
data[outkey or inkey] = array[idx]
return data
return _getidx
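
# Usage sketch (hypothetical keys): getidx('captions', 'answer_idx',
# outkey='caption') stores data['captions'][data['answer_idx']] under
# data['caption']; with outkey omitted it overwrites data['captions'].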