import ml_collections

YOUCOOK_TRAIN_SIZE = 1333  # Number of videos.

def get_config(runlocal=''):
  """Returns the base experiment configuration."""
  runlocal = bool(runlocal)

  config = ml_collections.ConfigDict()
  config.token_loss_coef = 1.
  config.runlocal = runlocal
  config.experiment_name = 'youcook'
  config.count_flops = False  # if runlocal else ml_collections.ConfigDict({'count_flops': True})

  # Dataset.
  config.dataset_name = 'dense_video_captioning'
  config.dataset_configs = ml_collections.ConfigDict()
  config.dataset_configs.corrupt = 0.
  config.dataset_configs.span_len = 3.
  config.dataset_configs.preserve = True
  config.dataset_configs.corrupt_coef = 0.
  config.dataset_configs.proba_corrupt = 0.
  notime = ml_collections.config_dict.FieldReference(False)
  config.dataset_configs.notime = notime
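  # Note: FieldReferences created here (`notime`, and `tmp_only`, `order`,
  # `num_frames`, `num_bins` below) are shared between the dataset and model
  # configs, so a single override propagates to every field built from the
  # same reference.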
  config.dataset_configs.abs_time_token = False
  config.dataset_configs.random_temporal_crop_proba = 0.5
  config.dataset_configs.time_format = 'se'
  tmp_only = ml_collections.config_dict.FieldReference(False)
  config.dataset_configs.tmp_only = tmp_only
  config.dataset_configs.split = False
  order = ml_collections.config_dict.FieldReference('ld')
  config.dataset_configs.order = order
  config.dataset_configs.from_xm = None

  config.data_dtype_str = 'float32'

  config.dataset_configs.base_dir = '/mnt/petrelfs/wangyiqin/vid_cap/examples'
  config.dataset_configs.tables = {
      'train': 'train.tfrecord.sst@64',
      'validation': 'test@1',
  }
  config.dataset_configs.examples_per_subset = {
      'train': 0,
      'validation': 1,
  }
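  # Assumption: the `@N` suffix follows the usual sharded-file naming
  # convention, e.g. `train.tfrecord.sst@64` denotes a table split into 64
  # shards.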
  # List of modalities to load; only `features` is supported for now.
  # Note that this only specifies which modalities to load, not which to use,
  # which is controlled by config.model.modality_fusion.
  config.dataset_configs.modalities = ('features', 'text')
  config.dataset_configs.features_dim = 768
  config.dataset_configs.return_as_dict = True
  num_frames = ml_collections.config_dict.FieldReference(100)
  config.dataset_configs.num_frames = num_frames
  num_bins = ml_collections.config_dict.FieldReference(100)
  config.dataset_configs.num_bins = num_bins
  config.dataset_configs.one_hot_labels = True
  config.dataset_configs.zero_centering = True
  config.dataset_configs.val_on_test = False
  config.dataset_configs.num_eval_clips = 1
  config.dataset_configs.prefetch_to_device = 2

  # Text parameters.
  config.dataset_configs.max_num_output_words = 256
  config.dataset_configs.max_num_input_words = 1000
  config.dataset_configs.tokenizer = ml_collections.ConfigDict()
  config.dataset_configs.tokenizer.tokenizer_type = 'sentence_piece'
  config.dataset_configs.caption_string = 'caption/string'
  config.dataset_configs.train_caption_string = 'caption/string'
  config.dataset_configs.input_timestamp_name = 'video/timestamps'
  config.dataset_configs.input_duration_name = 'video/duration'
  config.dataset_configs.output_raw_timestamp_name = 'timestamp'
  config.dataset_configs.output_raw_duration_name = 'duration'
  config.dataset_configs.input_feature_name = 'image/clip_embeddings'
  config.dataset_configs.output_raw_feature_name = 'features'
  config.dataset_configs.vocabulary_size = 32128
  config.dataset_configs.max_events = 20
  config.dataset_configs.asr_notime = False
  config.datasets = {'youcook': config.dataset_configs}

  # Decoding.
  config.decoding = ml_collections.ConfigDict()
  config.decoding.decoding_method = 'beamsearch'
  # config.decoding.decoding_method = 'temperature_sample'
  config.decoding.num_decodes = 4
  config.decoding.alpha = 1
  config.decoding.temperature = 1.
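  # With beam search decoding, `num_decodes` is the beam size and `alpha` the
  # length-normalization penalty; `temperature` only affects the
  # (commented-out) temperature-sampling alternative.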

  # Model.
  config.model_name = 'vid2seq'
  config.model = ml_collections.ConfigDict()
  config.model.from_xm = None

  # Encoder configs.
  config.model.encoder = ml_collections.ConfigDict()
  config.model.encoder.share_encoder = True
  config.model.encoder.encoder_type = 'cat_encoder'
  config.model.encoder.cat_encoder = ml_collections.ConfigDict()
  config.model.encoder.cat_encoder.dim = 2048
  config.model.encoder.cat_encoder.layers = 12
  config.model.encoder.cat_encoder.heads = 12
  config.model.encoder.cat_encoder.pos_embed = 'learned_1d'
  config.model.encoder.cat_encoder.dropout_rate = 0.
  config.model.encoder.cat_encoder.t5_dropout_rate = 0.1
  config.model.encoder.cat_encoder.stochastic_depth = 0.
  config.model.encoder.cat_encoder.pretrained_config = 't5_1_1_base'
  config.model.encoder.from_xm = None
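  # The `cat_encoder` (presumably "concatenation encoder") joins the
  # per-modality token sequences, i.e. the visual features and the transcribed
  # speech tokens, into a single T5-style encoder input.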

  # Decoder configs.
  config.model.decoder_type = 't5_decoder'
  config.model.decoder = ml_collections.ConfigDict()
  config.model.decoder.order = order
  config.model.decoder.t5_decoder = ml_collections.ConfigDict()
  config.model.decoder.t5_decoder.logits_via_embedding = False
  config.model.decoder.t5_decoder.dropout_rate = 0.1
  config.model.decoder.t5_decoder.num_frames = num_frames
  config.model.decoder.notime = notime
  config.model.decoder.num_bins = num_bins
  config.model.decoder.tmp_only = tmp_only
  config.model.decoder.t5_decoder.pretrained_config = 't5_1_1_base'
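  # `num_bins` relative time tokens are added on top of the text vocabulary so
  # the decoder can emit quantized timestamps (the 'se' time format above
  # presumably meaning start/end tokens) interleaved with caption tokens.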

  # Initialisation configs.
  config.init_from = ml_collections.ConfigDict()
  # Replace with your checkpoint pretrained on YT-temporal-1bn, assuming it has
  # been trained for 200K iterations.
  config.init_from.checkpoint_path = '/mnt/petrelfs/wangyiqin/vid_cap/vid2seq_model'
  # config.init_from.model_config = '/mnt/petrelfs/wangyiqin/vid_cap/scenic/scenic/projects/vid2seq/configs/yttemporal.py'
  config.init_from.step = 200001  # ytt 200000, anet 200001
  config.init_from.encoder = ml_collections.ConfigDict()
  config.init_from.encoder.checkpoint_path = None
  config.init_from.encoder.init_from_vit = False
  config.init_from.encoder.load_pretrained_weights = True
  config.init_from.decoder = ml_collections.ConfigDict()
  config.init_from.decoder.load_pretrained_weights = True
  config.init_from.t5 = ml_collections.ConfigDict()
  config.init_from.t5.load_pretrained_weights = True

  # Training.
  config.trainer_name = 'densevidcap_trainer'
  config.optimizer = 'adam'
  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.weight_decay = 0.
  config.l2_decay_factor = 0.
  config.max_grad_norm = 1.
  config.label_smoothing = 0.1
  epochs = ml_collections.config_dict.FieldReference(0)
  config.num_training_epochs = epochs
  batch_size = ml_collections.config_dict.FieldReference(1)
  config.batch_size = 1  # if runlocal else batch_size (originally 128; minimum is num_devices)
  config.eval_batch_size = 1  # if runlocal else 32 (needs to be num_local_devices)
  config.rng_seed = 0

  # Learning rate schedule.
  steps_per_epoch = 3 if runlocal else YOUCOOK_TRAIN_SIZE // batch_size
  total_steps = epochs * steps_per_epoch
  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.learning_rate_schedule = 'compound'
  config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
  config.lr_configs.warmup_steps = total_steps // 10
  config.lr_configs.steps_per_cycle = total_steps
  config.lr_configs.total_steps = total_steps
  config.lr_configs.base_learning_rate = 3e-4
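  # With the values above, steps_per_epoch = 1333 // 1 = 1333 (runlocal
  # defaults to False), but `epochs` is 0, so total_steps, warmup_steps and
  # steps_per_cycle all resolve to 0 and no training steps are scheduled;
  # raise `epochs` to actually fine-tune.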

  config.eval_metrics = ['cider', 'meteor', 'soda']

  # Logging.
  config.log_eval_steps = steps_per_epoch  # Write TB and/or XM summary.
  config.log_summary_steps = steps_per_epoch  # Write TB and/or XM summary.
  config.write_summary = True  # Write TB and/or XM summary.
  config.write_xm_measurements = True  # Write XM measurements.
  config.xprof = True  # Profile using xprof.
  config.checkpoint = True  # Do checkpointing.
  config.debug_train = False  # Debug mode during training.
  config.debug_eval = True  # Debug mode during eval.

  return config
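

# A minimal usage sketch (assumption: run as a standalone script; in practice
# Scenic loads this file through its --config flag):
if __name__ == '__main__':
  cfg = get_config()
  # FieldReference fields stay linked: overriding the dataset's `num_frames`
  # also updates the decoder field built from the same reference.
  cfg.dataset_configs.num_frames = 50
  assert cfg.model.decoder.t5_decoder.num_frames == 50
  print(cfg.model_name, cfg.lr_configs.total_steps)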