r"""Trains CLIP with Pixels Only (CLIPPO), https://arxiv.org/abs/2212.08045 |
|
|
|
IMPORTANT NOTE: This config uses coco_captions by default for demonstration |
|
purposes since the TFDS catalog does not provide any large image/alt-text data |
|
set; the training will not produce a model with useful accuracy. Please |
|
replace the data set below (marked by a comment) with an appropriate image/ |
|
alt-text data set wrapped in TFDS (for example LAION-400M) and run the config |
|
with the suffix `:test_with_coco=False` to train on your data set. Refer to |
|
the following guide to build a TFDS wrapper for your favorite image/alt-text |
|
data set: |
|
https://www.tensorflow.org/datasets/add_dataset |
|
|
|
Also note that evaluation on ImageNet requires manual TFDS setup, see |
|
https://github.com/google-research/big_vision#preparing-tfds-data |
|
|
|
|
|
Example training: |
|
|
|
big_vision.trainers.proj.image_text.contrastive \ |
|
--config big_vision/configs/proj/clippo/train_clippo.py \ |
|
--workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'` |
|
|
|
""" |
|
|
|
import big_vision.configs.common as bvcc |
|
from big_vision.configs.common_fewshot import get_fewshot_lsr |
|
from big_vision.configs.proj.image_text import common |
|
from ml_collections import ConfigDict |
|
|
|
|
|
def get_config(arg=None):
  """The base configuration."""
  arg = bvcc.parse_arg(
      arg, res=224, runlocal=False, variant='B/16',
      test_with_coco=True, i1k_eval=False)
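  # The config-string arguments above can be overridden at launch time by
  # appending them to the config path (parsed by bvcc.parse_arg), e.g.
  # `--config .../train_clippo.py:test_with_coco=False,variant=L/16`;
  # `:test_with_coco=False` is the suffix mentioned in the module docstring.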
|
  config = ConfigDict()

  config.input = {}
  if arg.test_with_coco:
    # Use COCO Captions for demonstration/sanity-checking purposes only.
    config.input.data = dict(name='coco_captions', split='train')
    val_data = dict(config.input.data)
    val_data['split'] = 'val'
    config.input.batch_size = 4000 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 50
    config.total_steps = 400 if not arg.runlocal else 10
  else:
    # Replace the two None values below with your own TFDS-wrapped
    # image/alt-text data set (see the module docstring).
    config.input.data = None
    val_data = None
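    # For example, assuming you have wrapped LAION-400M in TFDS under the
    # hypothetical name 'laion400m' (see the guide linked in the module
    # docstring), this could look like:
    #   config.input.data = dict(name='laion400m', split='train')
    #   val_data = dict(name='laion400m', split='validation')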
|
    assert config.input.data is not None and val_data is not None, (
        config.input.data, val_data)

    config.input.batch_size = 8 * 1024 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50
    config.total_steps = 100_000 if not arg.runlocal else 10

  def tokenizer(inkey, outkey='labels'):
    # CLIPPO has no separate text tower: text is rendered into an image with
    # the unifont bitmap font and then processed by the same vision encoder.
    return (f'render_unifont('
            f'inkey="{inkey}", '
            f'outkey="{outkey}", '
            f'image_size={arg.res}, '
            f'lower=True, '
            f'font_size=16, '
            f'text_brightness=0, '
            f'background_brightness=127)|'
            f'value_range(-1, 1, inkey="{outkey}", outkey="{outkey}")')

  pp_image = f'decode|resize({arg.res})|value_range(-1,1)'
  if arg.test_with_coco:
    # Add mild augmentation to reduce overfitting on the small COCO data set.
    pp_image_aug = (
        f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)')
    config.input.pp = pp_eval = (
        f'{pp_image_aug}|flatten|{tokenizer("captions/text")}|'
        f'keep("image", "labels")')
  else:
    config.input.pp = pp_eval = (
        f'{pp_image}|flatten|{tokenizer("text")}|keep("image", "labels")')
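  # For the default COCO setup with res=224, config.input.pp expands to the
  # following preprocessing string (wrapped here for readability):
  #   decode|resize(224)|flip_lr|randaug(2,10)|value_range(-1,1)|flatten|
  #   render_unifont(inkey="captions/text", outkey="labels", image_size=224,
  #     lower=True, font_size=16, text_brightness=0, background_brightness=127)|
  #   value_range(-1, 1, inkey="labels", outkey="labels")|keep("image", "labels")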

  config.pp_modules = [
      'ops_general', 'ops_image', 'ops_text', 'proj.clippo.pp_ops']

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000

  config.loss_use_global_batch = True

  # CLIPPO: a single shared tower encodes both images and rendered text.
  config.model_name = 'proj.clippo.one_tower'

  config.model = ConfigDict()
  config.model.image_model = 'vit'
  config.model.image = ConfigDict({
      'variant': arg.variant,
      'pool_type': 'map',
      'head_zeroinit': False,
  })

  if arg.test_with_coco:
    # Start the demonstration run from a pretrained ViT checkpoint; the head,
    # MAP-pooling and cls parameters are not loaded and thus re-initialized.
    assert arg.variant == 'B/16', arg.variant
    config.model_init = {'image': 'howto-i21k-B/16'}
    config.model_load = {}
    config.model_load['img_load_kw'] = {
        'dont_load': ['^head/.*', '^MAPHead_0/.*', 'cls']}

  config.model.temperature_init = 10.0
  config.model.out_dim = 768

  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0

  if arg.test_with_coco:
    # Short schedule for the COCO demonstration run.
    config.lr = 0.0001
    config.wd = 0.0003
    config.schedule = dict(decay_type='rsqrt',
                           timescale=100,
                           warmup_steps=100 if not arg.runlocal else 5,
                           cooldown_steps=100 if not arg.runlocal else 5)
  else:
    config.lr = 0.001
    config.wd = 0.0001
    config.schedule = dict(decay_type='rsqrt',
                           timescale=10_000,
                           warmup_steps=10_000 if not arg.runlocal else 5,
                           cooldown_steps=10_000 if not arg.runlocal else 5)

  eval_common = dict(
      type='proj.image_text.contrastive',
      use_global_batch=config.loss_use_global_batch,
      log_steps=1000 if not arg.runlocal else 5,
  )

  config.evals = {}
  sub = '[:4]' if arg.runlocal else ''
  config.evals.val = {
      **eval_common,
      'data': val_data,
      'pp_fn': pp_eval,
  }
  config.evals.coco = {
      **eval_common,
      'data': dict(name='coco_captions', split=f'val{sub}'),
      'pp_fn': (
          f'{pp_image}|flatten|{tokenizer("captions/text")}|'
          f'keep("image", "labels")'),
  }

  if arg.i1k_eval:
    # Requires manual TFDS setup for imagenet2012, see the module docstring.
    config.evals.imagenet = {
        **eval_common,
        'data': dict(name='imagenet2012', split=f'validation{sub}'),
        'pp_fn': (
            f'{pp_image}|clip_i1k_label_names|'
            f'{tokenizer("labels")}|keep("image", "labels")'),
    }
    # Zero-shot classification: class names are rendered as images by the
    # same tokenizer and compared against the image embeddings.
    config.evals.disclf = dict(
        type='proj.image_text.discriminative_classifier',
        pp_txt=tokenizer('texts', 'labels'),
        prefix='z/0shot/',
        log_steps=5_000 if not arg.runlocal else 5)

  config.evals.retrieval_coco = common.get_coco(
      pp_img=f'resize({arg.res})|value_range(-1, 1)',
      pp_txt=tokenizer('texts'),
      log_steps=5_000 if not arg.runlocal else 5,
  )

  config.evals.fewshot = get_fewshot_lsr()
  config.evals.fewshot.log_steps = 5_000 if not arg.runlocal else 5
  config.evals.fewshot.representation_layer = 'img/pre_logits'

  config.seed = 0

  return config