import ml_collections
def get_config(runlocal=''):
"""Returns the base experiment configuration."""
runlocal = bool(runlocal)
config = ml_collections.ConfigDict()
config.token_loss_coef = 1.
config.runlocal = runlocal
config.experiment_name = 'ytt'
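# FLOPs counting is disabled when running locally.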
config.count_flops = False if runlocal else ml_collections.ConfigDict(
{'count_flops': True})
# dataset
config.dataset_name = 'dense_video_captioning'
config.dataset_configs = ml_collections.ConfigDict()
config.dataset_configs.corrupt = 0.25
config.dataset_configs.span_len = 5.
config.dataset_configs.proba_corrupt = 1.
config.dataset_configs.corrupt_coef = 1.
config.dataset_configs.preserve = False
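# The FieldReferences below are shared with the decoder config (see
# config.model.decoder.*), so overriding them in one place updates both.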
notime = ml_collections.config_dict.FieldReference(False)
config.dataset_configs.notime = notime
config.dataset_configs.abs_time_token = False
config.dataset_configs.random_temporal_crop_proba = 1.
config.dataset_configs.time_format = 'se'
tmp_only = ml_collections.config_dict.FieldReference(False)
config.dataset_configs.tmp_only = tmp_only
config.dataset_configs.split = not runlocal
order = ml_collections.config_dict.FieldReference('ld')
config.dataset_configs.order = order
config.dataset_configs.from_xm = None
config.data_dtype_str = 'float32'
config.dataset_configs.base_dir = '/path/to/yttemporal'
config.dataset_configs.tables = {
'train': 'train.tfrecord.sst@1024',
}
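# '@1024' in the table spec denotes the number of file shards.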
config.dataset_configs.examples_per_subset = {
'train': 14780275,
}
# List of modalities to load; only `features` is supported for now.
# Note that this only specifies which modalities to load, not which to use,
# which is controlled by config.model.modality_fusion.
config.dataset_configs.modalities = ('features', 'text')
config.dataset_configs.features_dim = 768
config.dataset_configs.return_as_dict = True
num_frames = ml_collections.config_dict.FieldReference(100)
config.dataset_configs.num_frames = num_frames
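# Number of discrete bins used to quantize timestamps into time tokens.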
num_bins = ml_collections.config_dict.FieldReference(100)
config.dataset_configs.num_bins = num_bins
config.dataset_configs.one_hot_labels = True
config.dataset_configs.zero_centering = True
config.dataset_configs.val_on_test = False
config.dataset_configs.num_eval_clips = 1
config.dataset_configs.prefetch_to_device = 2
# Text params
config.dataset_configs.max_num_output_words = 1000
config.dataset_configs.max_num_input_words = 1000
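# SentencePiece tokenizer; `vocabulary_size` below (32128) matches the
# standard T5 SentencePiece vocabulary.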
config.dataset_configs.tokenizer = ml_collections.ConfigDict()
config.dataset_configs.tokenizer.tokenizer_type = 'sentence_piece'
config.dataset_configs.caption_string = 'ASR/segment/label/string'
config.dataset_configs.train_caption_string = 'ASR/segment/label/string'
config.dataset_configs.input_timestamp_start_name = 'ASR/segment/start/timestamp'
config.dataset_configs.input_timestamp_end_name = 'ASR/segment/end/timestamp'
config.dataset_configs.input_duration_name = 'video/duration'
config.dataset_configs.output_raw_timestamp_name = 'timestamp'
config.dataset_configs.output_raw_duration_name = 'duration'
config.dataset_configs.input_feature_name = 'image/clip_embeddings'
config.dataset_configs.output_raw_feature_name = 'features'
config.dataset_configs.vocabulary_size = 32128
config.dataset_configs.max_events = 1100
config.dataset_configs.max_segments = 0
config.datasets = {'ytt': config.dataset_configs}
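# Further datasets could be mixed in by adding entries here, e.g.
# {'ytt': ..., 'other': other_dataset_configs} (hypothetical).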
# Decoding
config.decoding = ml_collections.ConfigDict()
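# Beam search decoding: `num_decodes` is the beam size and `alpha` the usual
# length-normalization exponent.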
config.decoding.decoding_method = 'beamsearch'
config.decoding.num_decodes = 4
config.decoding.alpha = 0.6
config.decoding.temperature = 1.
# Model
config.model_name = 'vid2seq'
config.model = ml_collections.ConfigDict()
config.model.from_xm = None
# Encoder configs
config.model.encoder = ml_collections.ConfigDict()
config.model.encoder.share_encoder = True
config.model.encoder.encoder_type = 'cat_encoder'
config.model.encoder.cat_encoder = ml_collections.ConfigDict()
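# Transformer width, depth and heads for the encoder that concatenates the
# visual features with the text embeddings (hence 'cat_encoder').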
config.model.encoder.cat_encoder.dim = 2048
config.model.encoder.cat_encoder.layers = 12
config.model.encoder.cat_encoder.heads = 12
config.model.encoder.cat_encoder.pos_embed = 'learned_1d'
config.model.encoder.cat_encoder.dropout_rate = 0.1
config.model.encoder.cat_encoder.t5_dropout_rate = 0.1
config.model.encoder.cat_encoder.stochastic_depth = 0.
config.model.encoder.cat_encoder.pretrained_config = 't5_1_1_base'
config.model.encoder.from_xm = None
# Decoder configs
config.model.decoder_type = 't5_decoder'
config.model.decoder = ml_collections.ConfigDict()
config.model.decoder.order = order
config.model.decoder.t5_decoder = ml_collections.ConfigDict()
config.model.decoder.t5_decoder.logits_via_embedding = False
config.model.decoder.t5_decoder.dropout_rate = 0.1
config.model.decoder.t5_decoder.num_frames = num_frames
config.model.decoder.notime = notime
config.model.decoder.num_bins = num_bins
config.model.decoder.tmp_only = tmp_only
# Obtained from scenic/projects/t5/model.py.
config.model.decoder.t5_decoder.pretrained_config = 't5_1_1_base'
config.model.decoder.t5_decoder.local = 5
config.model.tmp_decoder_type = 't5_decoder'
config.model.tmp_decoder = ml_collections.ConfigDict()
config.model.tmp_decoder.t5_decoder = ml_collections.ConfigDict()
config.model.tmp_decoder.t5_decoder.logits_via_embedding = False
config.model.tmp_decoder.t5_decoder.dropout_rate = 0.
config.model.tmp_decoder.t5_decoder.pretrained_config = 't5_1_1_base'
# Initialisation configs
config.init_from = ml_collections.ConfigDict()
config.init_from.step = None
config.init_from.xm = None
config.init_from.encoder = ml_collections.ConfigDict()
config.init_from.encoder.checkpoint_path = None
config.init_from.encoder.init_from_vit = False
config.init_from.encoder.load_pretrained_weights = True
config.init_from.decoder = ml_collections.ConfigDict()
config.init_from.decoder.load_pretrained_weights = True
config.init_from.t5 = ml_collections.ConfigDict()
config.init_from.t5.load_pretrained_weights = True
# Training
config.trainer_name = 'densevidcap_trainer'
config.optimizer = 'adam'
config.optimizer_configs = ml_collections.ConfigDict()
config.optimizer_configs.weight_decay = 0.
config.l2_decay_factor = 0.
config.max_grad_norm = 0.1
config.label_smoothing = 0.1
epochs = ml_collections.config_dict.FieldReference(10)
config.num_training_epochs = epochs
batch_size = ml_collections.config_dict.FieldReference(512)
config.batch_size = 1 if runlocal else batch_size  # Minimum is num_devices (32)
config.eval_batch_size = 1 if runlocal else 128 # Needs to be num_local_devices
config.rng_seed = 0
# Learning schedule.
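# 'constant * linear_warmup' ramps the learning rate linearly from 0 to
# base_learning_rate over warmup_steps, then holds it constant.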
config.lr_configs = ml_collections.ConfigDict()
config.lr_configs.learning_rate_schedule = 'compound'
config.lr_configs.factors = 'constant * linear_warmup'
config.lr_configs.warmup_steps = 1000
config.lr_configs.base_learning_rate = 1e-4
config.eval_metrics = ['cider', 'meteor', 'soda']
# Logging
config.log_summary_steps = 500 # write TB and/or XM summary
config.checkpoint_steps = 5000
config.log_eval_steps = 5000
config.write_summary = True # write TB and/or XM summary
config.write_xm_measurements = True # write XM measurements
config.xprof = True # Profile using xprof
config.checkpoint = True # do checkpointing
config.debug_train = False # debug mode during training
config.debug_eval = False # debug mode during eval
return config
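
# Example usage (a sketch, assuming the standard ml_collections flow): this
# file is typically consumed by a trainer binary via config_flags, e.g.
#   from ml_collections import config_flags
#   config_flags.DEFINE_config_file('config', 'path/to/this_config.py')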