# Copyright 2024 Big Vision Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Common things across all transfer configs.""" | |
TOKENIZER = 'gemma(tokensets=("loc", "seg"))' | |
def tok(**kw): | |
"""Creates the tokenization preprocessing string.""" | |
# Single entry point so that it's consistent everywhere and easier to switch. | |
kw.setdefault('model', TOKENIZER) | |
kw = ', '.join(f'{k}={repr(v)}' for k, v in kw.items()) | |
return f'tok({kw})' | |
def combine_and_keep_train(text_len, before=(), sep='\n'): | |
return '|'.join([ | |
*before, | |
tok(key='prefix', bos='yes'), | |
tok(key='suffix', eos='yes'), | |
tok(key='septok', text=sep), | |
# If masks confuse you, see (internal link) | |
'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_loss=[0, 0, 1])', # pylint: disable=line-too-long | |
# For training, we +1 since the trainer removes EOS. | |
f'tolen({text_len+1}, pad_value=0, key="text")', # Value doesn't matter. | |
f'tolen({text_len+1}, pad_value=1, key="mask_ar")', | |
f'tolen({text_len+1}, pad_value=0, key="mask_loss")', | |
'keep("image", "text", "mask_ar", "mask_loss")', | |
]) | |
def combine_and_keep_eval(text_len, keep=tuple(), before=(), sep='\n'): | |
return '|'.join([ | |
*before, | |
# Same as training, except that suffix is now the empty string. | |
# Meaning, we create text as [prefix separator pad], | |
# and the mask accordingly as [0 0 1] (with repeats of respective lengths) | |
tok(key='prefix', bos='yes'), | |
tok(key='septok', text=sep), | |
# At eval time, there can be also a suffix key in the data. If so it is | |
# tokenized without EOS and decoding will continue from it. | |
'setdefault("suffix", "")', | |
tok(key='suffix', eos='no'), | |
# If masks confuse you, see (internal link) | |
'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])', # pylint: disable=line-too-long | |
f'tolen({text_len}, pad_value=0, key="text")', # value doesn't matter. | |
f'tolen({text_len}, pad_value=1, key="mask_ar")', | |
f'tolen({text_len}, pad_value=0, key="mask_input")', | |
# And we need to keep everything that makes our evaluator happy. | |
'keep(' + ', '.join(f'"{x}"' for x in ( | |
'image', 'text', 'mask_ar', 'mask_input') + tuple(keep)) + ')', | |
]) | |