# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long
r"""Trains CLIP with Pixels Only (CLIPPO), https://arxiv.org/abs/2212.08045
IMPORTANT NOTE: This config uses coco_captions by default for demonstration
purposes since the TFDS catalog does not provide any large image/alt-text data
set; the training will not produce a model with useful accuracy. Please
replace the data set below (marked by a comment) with an appropriate image/
alt-text data set wrapped in TFDS (for example LAION-400M) and run the config
with the suffix `:test_with_coco=False` to train on your data set. Refer to
the following guide to build a TFDS wrapper for your favorite image/alt-text
data set:
https://www.tensorflow.org/datasets/add_dataset
Also note that evaluation on ImageNet requires manual TFDS setup, see
https://github.com/google-research/big_vision#preparing-tfds-data
Example training:
big_vision.trainers.proj.image_text.contrastive \
--config big_vision/configs/proj/clippo/train_clippo.py \
--workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`
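
To train on your own data set instead (with the COCO sanity-check setup
disabled), pass the config argument via the suffix syntax parsed by
`bvcc.parse_arg`, for example:

big_vision.trainers.proj.image_text.contrastive \
    --config big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False \
    --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`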
"""
import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
from big_vision.configs.proj.image_text import common
from ml_collections import ConfigDict


def get_config(arg=None):
  """The base configuration."""
  arg = bvcc.parse_arg(
      arg, res=224, runlocal=False, variant='B/16',
      test_with_coco=True, i1k_eval=False)
  config = ConfigDict()

  config.input = {}
  if arg.test_with_coco:
    # Use COCO Captions for sanity-checking
    config.input.data = dict(name='coco_captions', split='train')
    val_data = dict(config.input.data)
    val_data['split'] = 'val'
    config.input.batch_size = 4000 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 50
    config.total_steps = 400 if not arg.runlocal else 10
  else:
    # Please add your favorite image/alt-text dataset here
    config.input.data = None
    val_data = None
    # Note: this assert fails on purpose until a dataset is configured above.
    assert config.input.data is not None and val_data is not None, (
        config.input.data, val_data)

    # The value in the paper is 10 * 1024, which requires 128 TPUv3 cores or a
    # memory-optimized ViT implementation when running on 128 TPUv2 cores.
    config.input.batch_size = 8 * 1024 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50
    config.total_steps = 100_000 if not arg.runlocal else 10
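
  # CLIPPO uses no separate text tower or tokenizer: the "tokenizer" below is
  # a preprocessing op that renders the raw alt-text into an image (via the
  # unifont renderer), so the single ViT tower embeds images and text alike.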
  def tokenizer(inkey, outkey='labels'):
    return (f'render_unifont('
            f'inkey="{inkey}", '
            f'outkey="{outkey}", '
            f'image_size={arg.res}, '
            f'lower=True, '
            f'font_size=16, '
            f'text_brightness=0, '
            f'background_brightness=127)|'
            f'value_range(-1, 1, inkey="{outkey}", outkey="{outkey}")')
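
  # For illustration, at the default res=224, tokenizer("captions/text")
  # expands to the pp string:
  #   render_unifont(inkey="captions/text", outkey="labels", image_size=224,
  #                  lower=True, font_size=16, text_brightness=0,
  #                  background_brightness=127)
  #   |value_range(-1, 1, inkey="labels", outkey="labels")
  # In the pipelines below, `flatten` first flattens the nested TFDS feature
  # dict into "/"-joined keys (hence "captions/text").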
  pp_image = f'decode|resize({arg.res})|value_range(-1,1)'

  if arg.test_with_coco:
    # Train with augmentation when sanity-checking
    pp_image_aug = (
        f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)')
    config.input.pp = pp_eval = (
        f'{pp_image_aug}|flatten|{tokenizer("captions/text")}|'
        f'keep("image", "labels")')
  else:
    config.input.pp = pp_eval = (
        f'{pp_image}|flatten|{tokenizer("text")}|keep("image", "labels")')
  config.pp_modules = [
      'ops_general', 'ops_image', 'ops_text', 'proj.clippo.pp_ops']

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000

  config.loss_use_global_batch = True

  # Define the model
  config.model_name = 'proj.clippo.one_tower'
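  # 'one_tower' = a single shared image encoder: both the input image and the
  # rendered text image are embedded by the same ViT configured below.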

  config.model = ConfigDict()
  config.model.image_model = 'vit'
  config.model.image = ConfigDict({
      'variant': arg.variant,
      'pool_type': 'map',
      'head_zeroinit': False,
  })

  if arg.test_with_coco:
    # Initialize with ImageNet21k pretrained checkpoint for sanity-checking
    assert arg.variant == 'B/16', arg.variant
    config.model_init = {'image': 'howto-i21k-B/16'}
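    # Skip loading the pretrained classifier head, MAP head, and cls
    # parameters; those stay freshly initialized for the contrastive setup.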
    config.model_load = {}
    config.model_load['img_load_kw'] = {
        'dont_load': ['^head/.*', '^MAPHead_0/.*', 'cls']}

  config.model.temperature_init = 10.0
  config.model.out_dim = 768

  # Define the optimizer
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0
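
  # Learning-rate schedule: reciprocal square-root decay with the given
  # timescale, plus warmup and cooldown phases at the start and end.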
  if arg.test_with_coco:
    # Short schedule for sanity-checking
    config.lr = 0.0001
    config.wd = 0.0003
    config.schedule = dict(decay_type='rsqrt',
                           timescale=100,
                           warmup_steps=100 if not arg.runlocal else 5,
                           cooldown_steps=100 if not arg.runlocal else 5)
  else:
    config.lr = 0.001
    config.wd = 0.0001
    config.schedule = dict(decay_type='rsqrt',
                           timescale=10_000,
                           warmup_steps=10_000 if not arg.runlocal else 5,
                           cooldown_steps=10_000 if not arg.runlocal else 5)

  # Eval section (Both few-shot and zero-shot)
  eval_common = dict(
      type='proj.image_text.contrastive',
      use_global_batch=config.loss_use_global_batch,
      log_steps=1000 if not arg.runlocal else 5,
  )

  config.evals = {}
  sub = '[:4]' if arg.runlocal else ''
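
  # `val` computes the contrastive metrics on the held-out split of the
  # training data; `coco` reports the same metrics on COCO Captions val.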
  config.evals.val = {
      **eval_common,
      'data': val_data,
      'pp_fn': pp_eval,
  }
  config.evals.coco = {
      **eval_common,
      'data': dict(name='coco_captions', split=f'val{sub}'),
      'pp_fn': (
          f'{pp_image}|flatten|{tokenizer("captions/text")}|'
          f'keep("image", "labels")'),
  }
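
  # ImageNet evals (enable with `:i1k_eval=True`): a contrastive eval on
  # rendered label names and a zero-shot discriminative classification eval;
  # both need the manually prepared imagenet2012 TFDS dataset.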
  if arg.i1k_eval:
    # Requires manual download, see
    # https://github.com/google-research/big_vision#preparing-tfds-data
    config.evals.imagenet = {
        **eval_common,
        'data': dict(name='imagenet2012', split=f'validation{sub}'),
        'pp_fn': (
            f'{pp_image}|clip_i1k_label_names|'
            f'{tokenizer("labels")}|keep("image", "labels")'),
    }
    config.evals.disclf = dict(
        type='proj.image_text.discriminative_classifier',
        pp_txt=tokenizer('texts', 'labels'),
        prefix='z/0shot/',
        log_steps=5_000 if not arg.runlocal else 5)
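
  # COCO image/text retrieval eval, with the caption side rendered as images.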
  config.evals.retrieval_coco = common.get_coco(
      pp_img=f'resize({arg.res})|value_range(-1, 1)',
      pp_txt=tokenizer('texts'),
      log_steps=5_000 if not arg.runlocal else 5,
  )

  # Few-shot metrics: linear probes on the frozen 'img/pre_logits'
  # representation.
  config.evals.fewshot = get_fewshot_lsr()
  config.evals.fewshot.log_steps = 5_000 if not arg.runlocal else 5
  config.evals.fewshot.representation_layer = 'img/pre_logits'

  config.seed = 0

  return config