# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Trains CLIP with Pixels Only (CLIPPO), https://arxiv.org/abs/2212.08045

IMPORTANT NOTE: By default this config uses coco_captions for demonstration
purposes, since the TFDS catalog does not provide any large image/alt-text
dataset; training on it will not produce a model with useful accuracy. Please
replace the dataset below (marked by a comment) with an appropriate
image/alt-text dataset wrapped in TFDS (for example LAION-400M) and run the
config with the suffix `:test_with_coco=False` (see the second example command
below) to train on your dataset. Refer to the following guide to build a TFDS
wrapper for your favorite image/alt-text dataset:
https://www.tensorflow.org/datasets/add_dataset

Also note that evaluation on ImageNet requires manual TFDS setup, see
https://github.com/google-research/big_vision#preparing-tfds-data


Example training:

big_vision.trainers.proj.image_text.contrastive \
    --config big_vision/configs/proj/clippo/train_clippo.py \
    --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`
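
To train on your own TFDS-wrapped dataset and enable the ImageNet zero-shot
eval, pass config arguments via the suffix, e.g. (the argument values shown
are illustrative, adjust them to your setup):

big_vision.trainers.proj.image_text.contrastive \
    --config big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False,i1k_eval=True \
    --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`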

"""

import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
from big_vision.configs.proj.image_text import common
from ml_collections import ConfigDict


def get_config(arg=None):
  """The base configuration."""
  arg = bvcc.parse_arg(
      arg, res=224, runlocal=False, variant='B/16',
      test_with_coco=True, i1k_eval=False)
  config = ConfigDict()

  config.input = {}
  if arg.test_with_coco:
    # Use COCO Captions for sanity-checking
    config.input.data = dict(name='coco_captions', split='train')
    val_data = dict(config.input.data)
    val_data['split'] = 'val'
    config.input.batch_size = 4000 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 50
    config.total_steps = 400 if not arg.runlocal else 10
  else:
    # Please add your favorite image/alt-text dataset here
    config.input.data = None
    val_data = None
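    # For example, a hypothetical TFDS wrapper named 'my_alt_text_dataset'
    # (illustrative, not an actual TFDS catalog entry) would be plugged in as:
    #   config.input.data = dict(name='my_alt_text_dataset', split='train')
    #   val_data = dict(name='my_alt_text_dataset', split='val')
    # Until these are set, the assert below fails on purpose.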
    assert config.input.data is not None and val_data is not None, (
        config.input.data, val_data)

    # The value in the paper is 10 * 1024, which requires 128 TPUv3 cores or a
    # memory-optimized ViT implementation when running on 128 TPUv2 cores.
    config.input.batch_size = 8 * 1024 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50
    config.total_steps = 100_000 if not arg.runlocal else 10

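  # CLIPPO uses no separate text tokenizer/encoder: alt-text is rendered into
  # an image with the `render_unifont` preprocessing op and then encoded by the
  # same ViT as the input image. This helper builds that preprocessing string.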
  def tokenizer(inkey, outkey='labels'):
    return (f'render_unifont('
            f'inkey="{inkey}", '
            f'outkey="{outkey}", '
            f'image_size={arg.res}, '
            f'lower=True, '
            f'font_size=16, '
            f'text_brightness=0, '
            f'background_brightness=127)|'
            f'value_range(-1, 1, inkey="{outkey}", outkey="{outkey}")')

  pp_image = f'decode|resize({arg.res})|value_range(-1,1)'
  if arg.test_with_coco:
    # Train with augmentation when sanity-checking
    pp_image_aug = (
        f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)')
    config.input.pp = pp_eval = (
        f'{pp_image_aug}|flatten|{tokenizer("captions/text")}|'
        f'keep("image", "labels")')
  else:
    config.input.pp = pp_eval = (
        f'{pp_image}|flatten|{tokenizer("text")}|keep("image", "labels")')

  config.pp_modules = [
      'ops_general', 'ops_image', 'ops_text', 'proj.clippo.pp_ops']

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000

  config.loss_use_global_batch = True

  # Define the model
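  # CLIPPO is a one-tower model: a single ViT encodes both the input image and
  # the rendered alt-text image, instead of the usual two-tower CLIP setup.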
  config.model_name = 'proj.clippo.one_tower'

  config.model = ConfigDict()
  config.model.image_model = 'vit'
  config.model.image = ConfigDict({
      'variant': arg.variant,
      'pool_type': 'map',
      'head_zeroinit': False,
  })

  if arg.test_with_coco:
    # Initialize with ImageNet21k pretrained checkpoint for sanity-checking
    assert arg.variant == 'B/16', arg.variant
    config.model_init = {'image': 'howto-i21k-B/16'}
    config.model_load = {}
    config.model_load['img_load_kw'] = {
        'dont_load': ['^head/.*', '^MAPHead_0/.*', 'cls']}

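  # The contrastive loss temperature is a learnable parameter; 10.0 is only
  # its initial value.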
  config.model.temperature_init = 10.0
  config.model.out_dim = 768

  # Define the optimizer
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0

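  # Reciprocal square-root decay with warmup and cooldown; `timescale` controls
  # how quickly the learning rate decays after warmup.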
  if arg.test_with_coco:
    # Short schedule for sanity-checking
    config.lr = 0.0001
    config.wd = 0.0003
    config.schedule = dict(decay_type='rsqrt',
                           timescale=100,
                           warmup_steps=100 if not arg.runlocal else 5,
                           cooldown_steps=100 if not arg.runlocal else 5)
  else:
    config.lr = 0.001
    config.wd = 0.0001
    config.schedule = dict(decay_type='rsqrt',
                           timescale=10_000,
                           warmup_steps=10_000 if not arg.runlocal else 5,
                           cooldown_steps=10_000 if not arg.runlocal else 5)

  # Eval section (both few-shot and zero-shot)
  eval_common = dict(
      type='proj.image_text.contrastive',
      use_global_batch=config.loss_use_global_batch,
      log_steps=1000 if not arg.runlocal else 5,
  )
  config.evals = {}
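  # When running locally, restrict eval splits to their first 4 examples using
  # the TFDS sub-split syntax (e.g. 'val[:4]').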
  sub = '[:4]' if arg.runlocal else ''
  config.evals.val = {
      **eval_common,
      'data': val_data,
      'pp_fn': pp_eval,
  }
  config.evals.coco = {
      **eval_common,
      'data': dict(name='coco_captions', split=f'val{sub}'),
      'pp_fn': (
          f'{pp_image}|flatten|{tokenizer("captions/text")}|'
          f'keep("image", "labels")'),
  }

  if arg.i1k_eval:
    # Requires manual download, see
    # https://github.com/google-research/big_vision#preparing-tfds-data
    config.evals.imagenet = {
        **eval_common,
        'data': dict(name='imagenet2012', split=f'validation{sub}'),
        'pp_fn': (
            f'{pp_image}|clip_i1k_label_names|'
            f'{tokenizer("labels")}|keep("image", "labels")'),
    }
    config.evals.disclf = dict(
        type='proj.image_text.discriminative_classifier',
        pp_txt=tokenizer('texts', 'labels'),
        prefix='z/0shot/',
        log_steps=5_000 if not arg.runlocal else 5)

  config.evals.retrieval_coco = common.get_coco(
      pp_img=f'resize({arg.res})|value_range(-1, 1)',
      pp_txt=tokenizer('texts'),
      log_steps=5_000 if not arg.runlocal else 5,
  )

  # Few-shot metrics
  config.evals.fewshot = get_fewshot_lsr()
  config.evals.fewshot.log_steps = 5_000 if not arg.runlocal else 5
  config.evals.fewshot.representation_layer = 'img/pre_logits'

  config.seed = 0

  return config