|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""Trains a CapPa model (https://arxiv.org/abs/2306.07915) on coco_captions. |
|
|
|
This config is for reference only; we never ran a full training on a large
image/text dataset on public infrastructure.
|
|
|
big_vision.trainers.proj.cappa.generative \ |
|
--config big_vision/configs/proj/cappa/pretrain.py \ |
|
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` |
|
""" |
|
|
|
|
|
from big_vision.configs import common_fewshot |
|
import big_vision.configs.common as bvcc |
|
import ml_collections |
|
|
|
|
|
def get_config(arg=None):
    """Builds the CapPa pretraining config for `coco_captions`.

    Args:
      arg: Passed straight to `bvcc.parse_arg` together with the default
        values for `runlocal`, `total_steps`, `batch_size` and
        `warmup_steps`.

    Returns:
      The config object produced by `bvcc.parse_arg`, filled in with the
      input pipeline, evaluators, model, optimizer and schedule settings.
    """
    config = bvcc.parse_arg(
        arg,
        runlocal=False,
        total_steps=366_500,
        batch_size=8 * 1024,
        warmup_steps=10_000,
    )
    # Local runs use tiny batch / buffer / warmup / eval-interval values.
    is_local = config.runlocal

    config.evals = {}
    config.input = {}
    config.input.batch_size = 8 if is_local else config.batch_size

    image_res = 224
    patch = 16
    text_len = 64

    pp_image = f'resize({image_res})|value_range(-1,1)'

    def tokenize_op(inkey, outkey):
        """Returns the `tokenize(...)` preprocessing op string."""
        return (
            f'tokenize(max_len={text_len}, model="c4_en", eos="sticky", '
            f'inkey="{inkey}", outkey="{outkey}")'
        )

    # Training / validation preprocessing: ops are chained with '|'.
    pp_coco = '|'.join([
        'decode',
        pp_image,
        'coco_captions("captions")',
        'choice(inkey="captions", outkey="text")',
        tokenize_op('text', 'labels'),
        'keep("image", "labels")',
    ])
    config.input.pp = pp_coco
    config.input.data = {'name': 'coco_captions', 'split': 'train'}
    config.input.shuffle_buffer_size = 50 if is_local else 50_000

    # Perplexity on the COCO validation split.
    config.evals.val_coco = dict(
        type='proj.cappa.perplexity',
        pred='perplexity',
        log_steps=1000,
        data={'name': 'coco_captions', 'split': 'val'},
        pp_fn=pp_coco,
    )

    # Few-shot evaluator; evaluation reuses the training preprocessing.
    config.evals.fewshot = common_fewshot.get_fewshot_lsr(
        target_resolution=image_res,
        resize_resolution=int(256 / 224 * image_res),
    )
    fewshot = config.evals.fewshot
    fewshot.type = 'fewshot_lsr'
    fewshot.log_steps = 5 if is_local else 5_000
    fewshot.representation_layer = 'pre_logits'
    fewshot.pred = 'enc_rep'
    fewshot.pp_eval = fewshot.pp_train

    # Scoring-based classification on the ImageNet validation split.
    config.evals['imagenet/scoring'] = {
        'type': 'proj.cappa.scoring_classifier',
        'pred': 'score',
        'log_percent': 0.1,
        'data': {'name': 'imagenet2012', 'split': 'validation'},
        'pp_fn': f'decode|{pp_image}|keep("image", "label")',
        'pp_txt': tokenize_op('label', 'labels'),
    }

    # Every evaluator configured above gets `skip_first = True`.
    for eval_cfg in config.evals.values():
        eval_cfg.skip_first = True

    config.log_training_steps = 50
    config.ckpt_steps = 1000
    config.keep_ckpt_steps = None

    # Model: encoder settings first, then decoder settings.
    config.model_name = 'proj.cappa.cappa'
    model = ml_collections.ConfigDict()
    model.num_layers = 12
    model.num_heads = 12
    model.mlp_dim = 3072
    model.emb_dim = 768
    model.vocab_size = 32_000
    model.patches = (patch, patch)
    model.seq_len = text_len
    model.posemb_type = 'learn'

    model.decoder_num_layers = 6
    # NOTE(review): the zeros below presumably mean "inherit the encoder
    # value" -- confirm against the model implementation.
    model.decoder_num_heads = 0
    model.decoder_mlp_dim = 0
    model.decoder_emb_dim = 0
    model.dec_dropout_rate = 0.0
    model.masked_pred_prob = 0.75
    model.masking_ratio = 1.0
    model.decoder_bias = False
    config.model = model

    # Optimizer.
    config.optax_name = 'big_vision.scale_by_adafactor'
    config.optax = {'beta2_cap': 0.999}
    config.grad_clip_norm = 1.0
    config.label_smoothing = 0.0

    # Learning rate, weight decay and cosine schedule with warmup.
    config.lr = 0.001
    config.wd = 0.0001
    config.schedule = {
        'decay_type': 'cosine',
        'warmup_steps': 5 if is_local else config.warmup_steps,
    }

    config.seed = 0

    return config