|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""PaliGemma transfer to COCO-35L captions. |
|
""" |
|
|
|
import big_vision.configs.common as bvcc |
|
from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER |
|
|
|
# Language codes of the COCO-35L train/dev splits (35 languages, sorted).
LANGUAGES = tuple(
    'ar bn cs da de el en es fa fi fil fr '
    'he hi hr hu id it ja ko mi nl no pl '
    'pt ro ru sv sw te th tr uk vi zh'.split()
)

# XM3600 covers the same languages plus one extra code, 'quz'; kept sorted.
LANGUAGES_XM3600 = tuple(sorted(LANGUAGES + ('quz',)))

# Languages whose caption evaluators use log_percent=0.25 instead of 1.0.
LANGUAGES_SUBSET = ('ar', 'bn', 'en', 'id', 'sw', 'tr', 'zh')
|
|
|
|
|
def training_data(res, lang=None, text_len=32, crop='rs'):
  """Builds the input-section config for COCO-35L caption training.

  See (internal link)

  You can add more arguments beside `res`, but give them good defaults.

  Args:
    res: Requested image resolution (e.g. 224).
    lang: Language code selecting the single split `train_<lang>`. When falsy,
      the train splits of every language in `LANGUAGES` are joined with '+'.
    text_len: Sequence length forwarded to `combine_and_keep_train`.
    crop: Image augmentation strategy, one of {'ic', 'rc', 'rs'}.

  Returns:
    The ConfigDict for the input section, with `data` (dataset name and split)
    and `pp` (the '|'-joined preprocessing pipeline).

  Raises:
    ValueError: If `crop` is not one of the supported strategies.
  """
  cfg = bvcc.parse_arg('')
  if lang:
    split = f'train_{lang}'
  else:
    split = '+'.join(f'train_{code}' for code in LANGUAGES)
  cfg.data = dict(name='coco35l', split=split)

  # Each strategy builds its op string lazily, so only the chosen one is
  # evaluated; matching uses `==` so any non-matching value hits the error.
  crop_builders = (
      ('ic', lambda: f'inception_crop({res}, area_min=50)'),
      ('rc', lambda: f'resize_small({res*8//7})|random_crop({res})'),
      ('rs', lambda: f'resize({res})'),
  )
  for strategy, build in crop_builders:
    if crop == strategy:
      crop_op = build()
      break
  else:
    raise ValueError(f'Unknown crop: {crop}')

  pipeline = (
      'flatten',
      'decode',
      crop_op,
      'value_range(-1, 1)',
      # Presumably samples one of the reference `captions` as the target.
      'choice_no_replacement(inkey="captions", outkey="suffix")',
      # `{language}` is filled per example (assumes a `language` field).
      'strfmt("caption {language}", outkey="prefix")',
      combine_and_keep_train(text_len),
  )
  cfg.pp = '|'.join(pipeline)
  return cfg
|
|
|
|
|
def _get_eval_pp(res, lang, text_len=32):
  """Returns the eval preprocessing pipeline string for captions in `lang`.

  Images are only resized to `res` (no crop augmentation). The prompt prefix
  is fixed to `caption <lang>`. `combine_and_keep_eval` is asked to keep
  `image/id` and `captions`.
  """
  steps = ['flatten', 'decode', f'resize({res})', 'value_range(-1, 1)']
  steps.append(f'strfmt("caption {lang}", outkey="prefix")')
  steps.append(combine_and_keep_eval(text_len, keep=('image/id', 'captions')))
  return '|'.join(steps)
|
|
|
|
|
def add_eval(c, res, text_len=32, langs=None, **kw):
  """Registers COCO-35L dev-split caption evaluators on `c.evals`.

  Adds one `coco35l/<lang>` evaluator (cider/bleu-4/meteor/rouge/spice
  metrics) per language.

  Args:
    c: Config whose `evals` mapping is populated in place.
    res: Eval image resolution.
    text_len: Max decode length; also forwarded to the eval preprocessing.
    langs: Languages to evaluate; falls back to `LANGUAGES` when falsy.
    **kw: Extra fields applied with `.update()` to each evaluator config after
      the defaults below are set.
  """
  for lang in langs or LANGUAGES:
    frequent = lang in LANGUAGES_SUBSET
    # Subset languages log at 25% intervals; all others use log_percent=1.0
    # together with skip_first=True.
    log_percent = 0.25 if frequent else 1.0
    key = f'coco35l/{lang}'
    c.evals[key] = dict(
        type='proj.paligemma.transfers.coco_caption',
        pred='decode',
        pred_kw={'max_decode_len': text_len},
        log_percent=log_percent,
        skip_first=not frequent,
        tokenizer=TOKENIZER,
        data=dict(name='coco35l', split=f'dev_{lang}'),
        cache='none',
        pp_fn=_get_eval_pp(res, lang, text_len),
    )
    c.evals[key].update(kw)
|
|
|
|
|
def add_eval_xm(c, res, text_len=32, langs=None, **kw):
  """Registers XM3600 caption evaluators on `c.evals`.

  Adds one `xm3600/<lang>` evaluator (cider/bleu-4/meteor/rouge/spice
  metrics) per language, reusing the COCO-35L eval preprocessing.

  Args:
    c: Config whose `evals` mapping is populated in place.
    res: Eval image resolution.
    text_len: Max decode length; also forwarded to the eval preprocessing.
    langs: Languages to evaluate; falls back to `LANGUAGES_XM3600` when falsy.
    **kw: Extra fields applied with `.update()` to each evaluator config after
      the defaults below are set.
  """
  for lang in langs or LANGUAGES_XM3600:
    frequent = lang in LANGUAGES_SUBSET
    log_percent = 0.25 if frequent else 1.0
    key = f'xm3600/{lang}'
    c.evals[key] = dict(
        type='proj.paligemma.transfers.coco_caption',
        pred='decode',
        pred_kw={'max_decode_len': text_len},
        log_percent=log_percent,
        skip_first=not frequent,
        tokenizer=TOKENIZER,
        data=dict(name='xm3600', split=lang),
        pp_fn=_get_eval_pp(res, lang, text_len),
    )
    c.evals[key].update(kw)
|
|
|
|
|
def add_eval_pplx(c, res, text_len=32):
  """Adds perplexity evaluators on English COCO-35L subsets.

  These are intended as a cheap sanity check before the real captioning evals.
  Data and preprocessing are taken from `training_data`, with only the split
  replaced.
  """
  train_cfg = training_data(res, text_len=text_len)
  eval_splits = {
      'minitrain': 'train_en[:2%]',
      'minival': 'dev_en[:5%]',
      'eval': 'dev_en',
  }
  for name, split in eval_splits.items():
    c.evals[f'coco35l/{name}/pplx'] = dict(
        type='proj.paligemma.perplexity',
        pred='logits',
        key='text',
        shift_labels=True,
        log_percent=0.05,
        data={**train_cfg.data, 'split': split},
        pp_fn=train_cfg.pp,
    )
|
|
|
|
|
def get_config(arg=None):
  """Builds the COCO-35L caption transfer config (optionally with XM3600 evals).

  Args:
    arg: Optional argument string parsed by `bvcc.parse_arg`. Fields and
      defaults: mode='xm', crop='rs', res=224, eval_xm3600=True, beam_size=0.
      A zero beam_size keeps the 'decode' predictor. A positive one switches
      the caption evals to 'beam_decode' with that beam size.

  Returns:
    The training config.
  """
  c = bvcc.parse_arg(
      arg, mode='xm', crop='rs', res=224, eval_xm3600=True, beam_size=0)

  # Input: one source per COCO-35L language, mixed with equal weights.
  c.input = {
      lang: training_data(c.res, lang=lang, crop=c.crop) for lang in LANGUAGES
  }
  c.input.data = dict.fromkeys(LANGUAGES, 1)
  for lang in LANGUAGES:
    c.input[lang].shuffle_buffer_size = 10_000

  # Optimisation and schedule.
  c.total_examples = 566_435
  c.input.batch_size = 256
  c.optax_name = 'scale_by_adam'
  c.optax = dict(b2=0.999)
  c.lr = 1e-4
  c.wd = 0.0
  c.grad_clip_norm = 1.0
  c.label_smoothing = 0.0
  c.schedule = dict(decay_type='cosine', warmup_percent=0.05)

  # Evaluators: perplexity sanity checks first, then captioning metrics.
  c.evals = {}
  add_eval_pplx(c, c.res)
  decode_kw = (
      {'pred': 'beam_decode', 'pred_kw': {'beam_size': c.beam_size}}
      if c.beam_size
      else {}
  )
  add_eval(c, c.res, batch_size=1024, **decode_kw)
  if c.eval_xm3600:
    add_eval_xm(c, c.res, batch_size=1024, **decode_kw)

  # Model: PaliGemma initialised from the checkpoint matching the resolution.
  c.model_name = 'proj.paligemma.paligemma'
  c.model = {}
  c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
  # NOTE(review): presumably base vocab plus 1024 + 128 extra tokens; confirm.
  c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
  c.model_init = f'pt_{c.res}'

  # Fully-sharded data parallelism over a single 'data' mesh axis.
  c.mesh = [('data', -1)]
  c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
  c.sharding_rules = [('act_batch', ('data',))]

  # Bookkeeping.
  c.log_training_steps = 50
  c.ckpt_steps = 1_000
  c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']

  if c.mode in ('runlocal', 'mock'):
    # Drop any existing slice and keep only the first 16 eval examples.
    for ev in c.evals.values():
      base_split, _, _ = ev.data.split.partition('[')
      ev.data.split = f'{base_split}[:16]'

  if c.mode == 'runlocal':
    c.log_training_steps = 1
    c.input.batch_size = 2

  c.seed = 0
  return c
|
|
|
|
|
def sweep_hyper(add):
  """Grid-sweeps learning rate, weight decay and number of epochs at res 224.

  Weight decay is either 0 or 0.1 times the learning rate. The training
  length is given in epochs of 566_435 examples.
  """
  epoch_size = 566_435
  grid = (
      (lr, wd, ep)
      for lr in (1e-5, 3e-6, 1e-6)
      for wd in (0.0, 0.1*lr)
      for ep in (1, 3, 5, 10, 20)
  )
  for lr, wd, ep in grid:
    add(lr=lr, wd=wd, total_examples=ep * epoch_size, **bvcc.arg(res=224))
|
|
|
|
|
def sweep_best(add, arg=None):
  """Trains with the best hyper-params for 5 epochs at resolutions 224 and 448.

  Args:
    add: Callback that registers one run from keyword arguments.
    arg: Optional argument string; its parsed fields (default
      eval_xm3600=True) are forwarded to every run.
  """
  c = bvcc.parse_arg(arg, eval_xm3600=True)
  epoch_size = 566_435
  for res in (224, 448):
    add(lr=1e-5, wd=1e-6, total_examples=5 * epoch_size,
        **bvcc.arg(res=res, **c))


# This config's `sweep` entry point is `sweep_best`.
sweep = sweep_best
|
|
|
|
|
def metrics(arg=None):
  """Returns the ('epoch', metric-name) pairs for this config's CIDEr scores.

  Args:
    arg: Optional argument string; only `eval_xm3600` (default True) is read.

  Returns:
    A list with one ('epoch', 'coco35l/<lang>/cider') pair per COCO-35L
    language. When `eval_xm3600` is set, it also has one
    ('epoch', 'xm3600/<lang>/cider') pair per XM3600 language.
  """
  c = bvcc.parse_arg(arg, eval_xm3600=True)
  m = [('epoch', f'coco35l/{lang}/cider') for lang in LANGUAGES]
  if c.eval_xm3600:
    # Mirror add_eval_xm, which evaluates every LANGUAGES_XM3600 language.
    # That includes 'quz', which is absent from LANGUAGES; iterating
    # LANGUAGES here silently dropped the xm3600/quz metric.
    m.extend(('epoch', f'xm3600/{lang}/cider') for lang in LANGUAGES_XM3600)
  return m
|
|