"""Command-line option definitions for captioning training and evaluation."""
from __future__ import print_function
import argparse


def if_use_feat(caption_model):
    """Decide which image features a caption model needs.

    Args:
        caption_model: name of the caption model (value of --caption_model).

    Returns:
        A ``(use_fc, use_att)`` tuple of booleans. Note the order: fc first,
        attention second.
    """
    if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
        # fc-only models
        use_att, use_fc = False, True
    elif caption_model == 'language_model':
        # no image features at all
        use_att, use_fc = False, False
    elif caption_model in ['updown', 'topdown']:
        # both feature kinds
        use_fc, use_att = True, True
    else:
        # every other model name falls back to attention features only
        use_att, use_fc = True, False
    return use_fc, use_att


def _require(condition, message):
    """Raise AssertionError(message) when *condition* is falsy.

    Unlike a bare ``assert`` statement this check is not removed when Python
    runs with ``-O``, while the exception type seen by callers stays the same.
    """
    if not condition:
        raise AssertionError(message)


def _apply_config(parser, args):
    """Overlay config-file / --set_cfgs values on *args*, then re-parse argv.

    Priority (lowest to highest): argparse defaults, the --cfg file,
    --set_cfgs key/value pairs, explicit command-line arguments.

    Args:
        parser: the fully populated ArgumentParser.
        args: namespace produced by a first ``parser.parse_args()`` call.

    Returns:
        The namespace to use from here on. When neither --cfg nor --set_cfgs
        was supplied this is *args* unchanged.
    """
    # --set_cfgs defaults to an empty list, so an "is not None" test would be
    # always true; use truthiness so the config module is only imported when
    # a config was actually requested.
    if args.cfg is None and not args.set_cfgs:
        return args

    from .config import CfgNode
    if args.cfg is not None:
        cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
    else:
        cn = CfgNode()
    if args.set_cfgs:
        cn.merge_from_list(args.set_cfgs)
    for k, v in cn.items():
        if not hasattr(args, k):
            print('Warning: key %s not in args' % k)
        setattr(args, k, v)
    # Re-parse the command line on top of the config values so that explicit
    # command-line arguments win over the config.
    return parser.parse_args(namespace=args)


def _validate_args(args):
    """Check option values; raise AssertionError on the first invalid one."""
    _require(args.rnn_size > 0, "rnn_size should be greater than 0")
    _require(args.num_layers > 0, "num_layers should be greater than 0")
    _require(args.input_encoding_size > 0, "input_encoding_size should be greater than 0")
    _require(args.batch_size > 0, "batch_size should be greater than 0")
    _require(0 <= args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1")
    _require(args.seq_per_img > 0, "seq_per_img should be greater than 0")
    _require(args.beam_size > 0, "beam_size should be greater than 0")
    _require(args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0")
    # NOTE(review): the --losses_log_every help advertises "0 = disable", but
    # 0 is rejected here. Kept as-is; confirm which one is intended.
    _require(args.losses_log_every > 0, "losses_log_every should be greater than 0")
    _require(args.language_eval in (0, 1), "language_eval should be 0 or 1")
    _require(args.load_best_score in (0, 1), "load_best_score should be 0 or 1")
    _require(args.train_only in (0, 1), "train_only should be 0 or 1")


def parse_opt():
    """Build the training option parser, parse ``sys.argv`` and post-process.

    Returns:
        An ``argparse.Namespace`` with every option below, plus the derived
        attributes ``use_fc`` and ``use_att``. ``checkpoint_path`` and
        ``start_from`` are filled with defaults when not given, and
        ``att_feat_size`` is increased by 5 when ``use_box`` is set.

    Raises:
        AssertionError: if an option value fails validation.
    """
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--input_json', type=str, default='data/coco.json',
                        help='path to the json file containing additional info and vocab')
    parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
                        help='path to the directory containing the preprocessed fc feats')
    parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
                        help='path to the directory containing the preprocessed att feats')
    parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
                        help='path to the directory containing the boxes of att feats')
    parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--data_in_memory', action='store_true',
                        help='True if we want to save the features in memory')
    parser.add_argument('--start_from', type=str, default=None,
                        help="continue training from saved model at this path. "
                             "Path must contain files saved by previous training process: "
                             "'infos.pkl' : configuration; 'model.pth' : weights ")
    parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
                        help='Cached token file for calculating cider score during self critical training.')

    # Model settings
    parser.add_argument('--caption_model', type=str, default="show_tell",
                        help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer')
    parser.add_argument('--rnn_size', type=int, default=512,
                        help='size of the rnn in number of hidden nodes in each layer')
    parser.add_argument('--num_layers', type=int, default=1,
                        help='number of layers in the RNN')
    parser.add_argument('--rnn_type', type=str, default='lstm',
                        help='rnn, gru, or lstm')
    parser.add_argument('--input_encoding_size', type=int, default=512,
                        help='the encoding size of each token in the vocabulary, and the image.')
    parser.add_argument('--att_hid_size', type=int, default=512,
                        help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
    parser.add_argument('--fc_feat_size', type=int, default=2048,
                        help='2048 for resnet, 4096 for vgg')
    parser.add_argument('--att_feat_size', type=int, default=2048,
                        help='2048 for resnet, 512 for vgg')
    # NOTE(review): help text duplicates --num_layers; presumably copy-paste.
    parser.add_argument('--logit_layers', type=int, default=1,
                        help='number of layers in the RNN')
    parser.add_argument('--use_bn', type=int, default=0,
                        help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')

    # feature manipulation
    parser.add_argument('--norm_att_feat', type=int, default=0,
                        help='If normalize attention features')
    parser.add_argument('--use_box', type=int, default=0,
                        help='If use box features')
    parser.add_argument('--norm_box_feat', type=int, default=0,
                        help='If use box, do we normalize box feature')

    # Optimization: General
    parser.add_argument('--max_epochs', type=int, default=-1,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='minibatch size')
    parser.add_argument('--grad_clip_mode', type=str, default='value',
                        help='value or norm')
    parser.add_argument('--grad_clip_value', type=float, default=0.1,
                        help='clip gradients at this value/max_norm, 0 means no clipping')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5,
                        help='strength of dropout in the Language Model RNN')
    parser.add_argument('--self_critical_after', type=int, default=-1,
                        help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
    parser.add_argument('--seq_per_img', type=int, default=5,
                        help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')

    # Sample related
    add_eval_sample_opts(parser)

    # Optimization: for the Language Model
    parser.add_argument('--optim', type=str, default='adam',
                        help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw')
    parser.add_argument('--learning_rate', type=float, default=4e-4,
                        help='learning rate')
    parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
                        help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
    parser.add_argument('--learning_rate_decay_every', type=int, default=3,
                        help='every how many iterations thereafter to drop LR?(in epoch)')
    # NOTE(review): help text duplicates --learning_rate_decay_every;
    # presumably copy-paste.
    parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
                        help='every how many iterations thereafter to drop LR?(in epoch)')
    parser.add_argument('--optim_alpha', type=float, default=0.9,
                        help='alpha for adam')
    parser.add_argument('--optim_beta', type=float, default=0.999,
                        help='beta used for adam')
    parser.add_argument('--optim_epsilon', type=float, default=1e-8,
                        help='epsilon that goes into denominator for smoothing')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight_decay')

    # Transformer
    parser.add_argument('--label_smoothing', type=float, default=0, help='')
    parser.add_argument('--noamopt', action='store_true', help='')
    parser.add_argument('--noamopt_warmup', type=int, default=2000, help='')
    parser.add_argument('--noamopt_factor', type=float, default=1, help='')
    parser.add_argument('--reduce_on_plateau', action='store_true', help='')
    parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5, help='')
    parser.add_argument('--reduce_on_plateau_patience', type=int, default=3, help='')
    parser.add_argument('--cached_transformer', action='store_true', help='')
    parser.add_argument('--use_warmup', action='store_true',
                        help='warm up the learing rate?')
    parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
                        help='at what iteration to start decay gt probability')
    parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
                        help='every how many iterations thereafter to gt probability')
    parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
                        help='How much to update the prob')
    parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
                        help='Maximum scheduled sampling prob.')

    # Evaluation/Checkpointing
    parser.add_argument('--val_images_use', type=int, default=3200,
                        help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
    parser.add_argument('--save_checkpoint_every', type=int, default=2500,
                        help='how often to save a model checkpoint (in iterations)?')
    parser.add_argument('--save_every_epoch', action='store_true',
                        help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
    parser.add_argument('--save_history_ckpt', type=int, default=0,
                        help='If save checkpoints at every save point')
    parser.add_argument('--checkpoint_path', type=str, default=None,
                        help='directory to store checkpointed models')
    parser.add_argument('--language_eval', type=int, default=0,
                        help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
    parser.add_argument('--losses_log_every', type=int, default=25,
                        help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
    parser.add_argument('--load_best_score', type=int, default=1,
                        help='Do we load previous best score when resuming training.')

    # misc
    parser.add_argument('--id', type=str, default='',
                        help='an id identifying this run/job. used in cross-val and appended when writing progress files')
    parser.add_argument('--train_only', type=int, default=0,
                        help='if true then use 80k, else use 110k')
    parser.add_argument('--topic', type=str, default='dress',
                        help='type of datasets, such as dress, shirt, toptee')

    # Reward
    parser.add_argument('--cider_reward_weight', type=float, default=1,
                        help='The reward weight from cider')
    parser.add_argument('--bleu_reward_weight', type=float, default=0,
                        help='The reward weight from bleu4')

    # Structure_loss
    parser.add_argument('--structure_loss_weight', type=float, default=1, help='')
    parser.add_argument('--structure_after', type=int, default=-1, help='T')
    parser.add_argument('--structure_loss_type', type=str, default='seqnll', help='')
    parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
    parser.add_argument('--entropy_reward_weight', type=float, default=0,
                        help='Entropy reward, seems very interesting')
    parser.add_argument('--self_cider_reward_weight', type=float, default=0,
                        help='self cider reward')

    # Used for self critical or structure. Used when sampling is need during training
    # NOTE(review): help text duplicates --cider_reward_weight; presumably
    # copy-paste.
    parser.add_argument('--train_sample_n', type=int, default=1,
                        help='The reward weight from cider')
    parser.add_argument('--train_sample_method', type=str, default='sample', help='')
    parser.add_argument('--train_beam_size', type=int, default=1, help='')

    # Used for self critical
    parser.add_argument('--sc_sample_method', type=str, default='greedy', help='')
    parser.add_argument('--sc_beam_size', type=int, default=1, help='')

    parser.add_argument('--seed', type=int, default=42, help='')

    # For diversity evaluation during training
    add_diversity_opts(parser)

    # config
    parser.add_argument('--cfg', type=str, default=None,
                        help='configuration; similar to what is used in detectron')
    parser.add_argument(
        '--set_cfgs', dest='set_cfgs',
        help='Set config keys. Key value sequence separated by whitespace. '
             'e.g. [key] [value] [key] [value]\n This has higher priority '
             'than cfg file but lower than other args. (You can only overwrite '
             'arguments that have already been defined in config file.)',
        default=[], nargs='+')

    # How will config be used
    # 1) read cfg argument, and load the cfg file if it's not None
    # 2) Overwrite cfg argument with set_cfgs
    # 3) parse config argument to args.
    # 4) in the end, parse command line argument and overwrite args
    args = parser.parse_args()
    args = _apply_config(parser, args)

    # Check if args are valid
    _validate_args(args)

    # default value for start_from and checkpoint_path
    args.checkpoint_path = args.checkpoint_path or './results/log_{}_{}'.format(args.topic, args.id)
    args.start_from = args.start_from or args.checkpoint_path

    # Deal with feature things before anything
    args.use_fc, args.use_att = if_use_feat(args.caption_model)
    if args.use_box:
        # box features add 5 extra dimensions to every attention feature
        args.att_feat_size = args.att_feat_size + 5

    return args


def add_eval_options(parser):
    """Register the evaluation-script options on *parser* (mutated in place).

    Also registers the sampling options via ``add_eval_sample_opts``.
    """
    # Basic options
    parser.add_argument('--batch_size', type=int, default=0,
                        help='if > 0 then overrule, otherwise load from checkpoint.')
    parser.add_argument('--num_images', type=int, default=-1,
                        help='how many images to use when periodically evaluating the loss? (-1 = all)')
    parser.add_argument('--language_eval', type=int, default=0,
                        help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
    parser.add_argument('--dump_images', type=int, default=1,
                        help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
    parser.add_argument('--dump_json', type=int, default=1,
                        help='Dump json with predictions into vis folder? (1=yes,0=no)')
    parser.add_argument('--dump_path', type=int, default=0,
                        help='Write image paths along with predictions into vis json? (1=yes,0=no)')

    # Sampling options
    add_eval_sample_opts(parser)

    # For evaluation on a folder of images:
    parser.add_argument('--image_folder', type=str, default='',
                        help='If this is nonempty then will predict on the images in this folder path')
    parser.add_argument('--image_root', type=str, default='',
                        help='In case the image paths have to be preprended with a root path to an image folder')

    # For evaluation on MSCOCO images from some split:
    # NOTE(review): the three *_dir help texts below repeat the h5file
    # wording of --input_label_h5; presumably copy-paste.
    parser.add_argument('--input_fc_dir', type=str, default='',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--input_att_dir', type=str, default='',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--input_box_dir', type=str, default='',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--input_label_h5', type=str, default='',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--input_json', type=str, default='',
                        help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
    parser.add_argument('--split', type=str, default='test',
                        help='if running on MSCOCO images, which split to use: val|test|train')
    parser.add_argument('--coco_json', type=str, default='',
                        help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')

    # misc
    parser.add_argument('--id', type=str, default='',
                        help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
    parser.add_argument('--verbose_beam', type=int, default=1,
                        help='if we need to print out all beam search beams.')
    parser.add_argument('--verbose_loss', type=int, default=0,
                        help='If calculate loss using ground truth during evaluation')
    parser.add_argument('--seed', type=int, default=42, help='')


def add_diversity_opts(parser):
    """Register the diverse-sampling options on *parser* (mutated in place)."""
    parser.add_argument('--sample_n', type=int, default=1,
                        help='Diverse sampling')
    parser.add_argument('--sample_n_method', type=str, default='sample',
                        help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp')
    parser.add_argument('--eval_oracle', type=int, default=1,
                        help='if we need to calculate loss.')


# Sampling related options
def add_eval_sample_opts(parser):
    """Register the sampling/decoding options on *parser* (mutated in place)."""
    parser.add_argument('--sample_method', type=str, default='greedy',
                        help='greedy; sample; gumbel; top, top<0-1>')
    parser.add_argument('--beam_size', type=int, default=1,
                        help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
    parser.add_argument('--max_length', type=int, default=8,
                        help='Maximum length during sampling')
    parser.add_argument('--length_penalty', type=str, default='',
                        help='wu_X or avg_X, X is the alpha')
    parser.add_argument('--group_size', type=int, default=1,
                        help="used for diverse beam search. if group_size is 1, then it's normal beam search")
    parser.add_argument('--diversity_lambda', type=float, default=0.5,
                        help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.')
    parser.add_argument('--decoding_constraint', type=int, default=0,
                        help='If 1, not allowing same word in a row')
    parser.add_argument('--block_trigrams', type=int, default=0,
                        help='block repeated trigram.')
    parser.add_argument('--remove_bad_endings', type=int, default=1,
                        help='Remove bad endings')
    parser.add_argument('--suppress_UNK', type=int, default=1,
                        help='Not predicting UNK')


if __name__ == '__main__':
    import sys

    def _changed_items(new_args, old_args):
        """Return the entries of *new_args* that are absent or differ in *old_args*.

        Done by key comparison rather than a set difference of items because
        some values (e.g. the --set_cfgs list) are unhashable.
        """
        old = vars(old_args)
        return dict((k, v) for k, v in vars(new_args).items()
                    if k not in old or old[k] != v)

    # Defaults only.
    sys.argv = [sys.argv[0]]
    args = parse_opt()
    print(args)
    print()

    # Defaults overlaid with a config file.
    sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml']
    args1 = parse_opt()
    print(_changed_items(args1, args))
    print()

    # Config file plus an explicit command-line override.
    sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2']
    args2 = parse_opt()
    print(_changed_items(args2, args1))