# pranavSIT's picture
# added pali inference
# 74e8f2f
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common things across all transfer configs."""
# Default tokenizer spec; `tok()` uses it whenever no `model=` is passed.
TOKENIZER = 'gemma(tokensets=("loc", "seg"))'


def tok(**kw):
  """Creates the tokenization preprocessing string.

  Args:
    **kw: Keyword arguments rendered into the op string. When `model` is not
      supplied it defaults to `TOKENIZER`.

  Returns:
    A string of the form `tok(k1=v1, k2=v2, ...)`. Each value is rendered with
    `repr()` and keywords keep the order in which they were passed; a
    defaulted `model` is appended last.
  """
  # Single entry point so that it's consistent everywhere and easier to switch.
  kw.setdefault('model', TOKENIZER)
  # Use a separate name for the rendered arguments instead of rebinding `kw`
  # from a dict to a string.
  rendered = ', '.join(f'{k}={v!r}' for k, v in kw.items())
  return f'tok({rendered})'
def combine_and_keep_train(text_len, before=(), sep='\n'):
  """Builds the training preprocessing string for text assembly and padding.

  The generated ops tokenize the "prefix" (with BOS), the "suffix" (with EOS)
  and the separator, concatenate them with their masks, pad the results to a
  fixed length and keep only the fields listed in the final `keep` op.

  Args:
    text_len: Text length; the padded fields use `text_len + 1`.
    before: Iterable of preprocessing op strings placed ahead of the ops
      generated here.
    sep: Separator text tokenized under the "septok" key (a newline by
      default).

  Returns:
    All op strings joined with '|'.
  """
  # For training, we +1 since the trainer removes EOS.
  padded_len = text_len + 1
  return '|'.join([
      *before,
      tok(key='prefix', bos='yes'),
      tok(key='suffix', eos='yes'),
      tok(key='septok', text=sep),
      # If masks confuse you, see (internal link)
      'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_loss=[0, 0, 1])',  # pylint: disable=line-too-long
      f'tolen({padded_len}, pad_value=0, key="text")',  # Value doesn't matter.
      f'tolen({padded_len}, pad_value=1, key="mask_ar")',
      f'tolen({padded_len}, pad_value=0, key="mask_loss")',
      'keep("image", "text", "mask_ar", "mask_loss")',
  ])
def combine_and_keep_eval(text_len, keep=tuple(), before=(), sep='\n'):
  """Builds the evaluation preprocessing string for text assembly and padding.

  Args:
    text_len: Length that "text", "mask_ar" and "mask_input" are padded to.
    keep: Extra field names to keep in addition to "image", "text", "mask_ar"
      and "mask_input".
    before: Iterable of preprocessing op strings placed ahead of the ops
      generated here.
    sep: Separator text tokenized under the "septok" key (a newline by
      default).

  Returns:
    All op strings joined with '|'.
  """
  # We need to keep everything that makes our evaluator happy.
  kept_fields = ('image', 'text', 'mask_ar', 'mask_input') + tuple(keep)
  keep_op = 'keep(' + ', '.join(f'"{x}"' for x in kept_fields) + ')'
  return '|'.join([
      *before,
      # Same as training, except that suffix is now the empty string.
      # Meaning, we create text as [prefix separator pad],
      # and the mask accordingly as [0 0 1] (with repeats of respective lengths)
      tok(key='prefix', bos='yes'),
      tok(key='septok', text=sep),
      # At eval time, there can be also a suffix key in the data. If so it is
      # tokenized without EOS and decoding will continue from it.
      'setdefault("suffix", "")',
      tok(key='suffix', eos='no'),
      # If masks confuse you, see (internal link)
      'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])',  # pylint: disable=line-too-long
      f'tolen({text_len}, pad_value=0, key="text")',  # value doesn't matter.
      f'tolen({text_len}, pad_value=1, key="mask_ar")',
      f'tolen({text_len}, pad_value=0, key="mask_input")',
      keep_op,
  ])