|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""Pre-train ViT-g (1B params) on JFT-3B as in https://arxiv.org/abs/2106.04560 |
|
|
|
To train ViT-G (2B params), simply update the following single line: |
|
`config.model.variant = 'G/14'` |
|
|
|
The code is released for reference purposes. |
|
One can test the code using the public ImageNet-1k or ImageNet-21k datasets (update `config.dataset`, `config.num_classes` and the splits accordingly).
|
|
|
python -m big_vision.train \
|
--config big_vision/configs/proj/scaling_laws/train_vit_g.py \ |
|
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` |
|
|
|
""" |
|
from big_vision.configs.common_fewshot import get_fewshot_lsr |
|
import ml_collections as mlc |
|
|
|
|
|
def get_config():
  """Returns the ViT-g/14 pre-training config for JFT-3B.

  Trains a MAP-pooled ViT-g/14 with the `big_vision.scale_by_adafactor`
  optimizer for 1M steps at a global batch size of 16,384, with periodic
  few-shot evaluation.

  Returns:
    An `ml_collections.ConfigDict` to be passed to `big_vision.train` via
    `--config`.
  """
  config = mlc.ConfigDict()

  # Data: JFT-3B and its 29,593-class label space.
  config.dataset = 'jft_3b'
  config.val_split = 'val'
  config.train_split = 'train'
  config.num_classes = 29_593
  # NOTE(review): the strongly negative initial head bias presumably keeps
  # the initial per-class scores low over the large label space — confirm.
  config.init_head_bias = -10.0

  # Global batch size (16,384 examples).
  config.batch_size = 4096 * 4

  # Preprocessing ops shared by train and eval: rescale pixel values to
  # [-1, 1], one-hot encode the labels and keep only image and labels.
  pp_common = '|value_range(-1, 1)'
  pp_common += f'|onehot({config.num_classes})'
  pp_common += '|keep("image", "labels")'
  config.pp_train = 'inception_crop(224)|flip_lr' + pp_common
  config.pp_eval = 'resize_small(256)|central_crop(224)' + pp_common
  config.shuffle_buffer_size = 250_000

  # Logging, evaluation and checkpointing cadence (in training steps).
  config.log_training_steps = 50
  config.log_eval_steps = 1000
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 10_000

  config.prefetch_to_device = 1
  config.trial = 0

  # Model: ViT-g/14 with MAP pooling. Use variant 'G/14' to train ViT-G.
  config.model_name = 'vit'
  config.model = mlc.ConfigDict()
  config.model.variant = 'g/14'
  config.model.pool_type = 'map'

  # Optimizer.
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0
  config.lr = 8e-4
  # Weight decay is 0.03 times the base learning rate. Deriving it from
  # `config.lr` (same value as the former `0.03 * 8e-4` literal) keeps the
  # two in sync when the learning rate is edited here.
  config.wd = 0.03 * config.lr
  # Per-parameter weight-decay multipliers keyed by regex on the parameter
  # path: the head kernel gets 100x, all other kernels 1x.
  # NOTE(review): the first matching pattern presumably wins, so the head
  # entry should stay first — confirm against the optimizer implementation.
  config.wd_mults = [
      ('.*head/kernel', 100.0),
      ('.*/kernel', 1.0),
  ]
  # Inverse-square-root decay with warmup and a final cooldown.
  config.schedule = dict(
      decay_type='rsqrt', timescale=10_000, warmup_steps=10_000,
      cooldown_steps=50_000)
  config.total_steps = 1_000_000

  # Few-shot evaluation every 10k steps. A dict display is used instead of
  # `dict(log_steps=..., **get_fewshot_lsr())` so that our `log_steps`
  # overrides any default supplied by the helper; the keyword form raises
  # TypeError ("got multiple values for keyword argument") in that case.
  config.evals = {}
  config.evals.fewshot = {**get_fewshot_lsr(), 'log_steps': 10_000}

  return config
|
|