"""Command-line entry point with ``train`` and ``test`` subcommands.

The ``train`` subcommand exposes the hyperparameters returned by
``make_hparams`` as command-line flags (via ``HParams.populate_arguments``)
and applies any overrides given on the command line before using them.
"""

import argparse

from benepar import nkutil


def make_hparams():
    """Return the default hyperparameters as an ``nkutil.HParams`` object."""
    return nkutil.HParams(
        # Data processing
        max_len_train=0,  # no length limit
        max_len_dev=0,  # no length limit
        # Optimization
        batch_size=32,
        learning_rate=0.00005,
        learning_rate_warmup_steps=160,
        clip_grad_norm=0.0,  # no clipping
        checks_per_epoch=4,
        step_decay_factor=0.5,
        step_decay_patience=5,
        max_consecutive_decays=3,  # establishes a termination criterion
        # CharLSTM
        use_chars_lstm=False,
        d_char_emb=64,
        char_lstm_input_dropout=0.2,
        # BERT and other pre-trained models
        use_pretrained=False,
        pretrained_model="bert-base-uncased",
        # Partitioned transformer encoder
        use_encoder=False,
        d_model=1024,
        num_layers=8,
        num_heads=8,
        d_kv=64,
        d_ff=2048,
        encoder_max_len=512,
        # Dropout
        morpho_emb_dropout=0.2,
        attention_dropout=0.2,
        relu_dropout=0.1,
        residual_dropout=0.2,
        # Output heads and losses
        force_root_constituent="auto",
        predict_tags=False,
        d_label_hidden=256,
        d_tag_hidden=256,
        tag_loss_scale=5.0,
    )


def run_train(args, hparams):
    """Handle the ``train`` subcommand.

    Args:
        args: parsed ``argparse.Namespace`` from the ``train`` subparser.
        hparams: the ``HParams`` whose fields were registered as flags on
            that subparser; it is updated in place from ``args``.
    """
    # populate_arguments() only registers the flags; copy the parsed values
    # back so overrides such as ``--d-kv 128`` actually take effect.
    hparams.set_from_args(args)
    print("Train:")
    print("dimension of attention key value: {}".format(hparams.d_kv))
    print("use pretrained: {}".format(hparams.use_pretrained))
    print("text processing mode: {}".format(args.text_processing))


def run_test(args):
    """Handle the ``test`` subcommand.

    Args:
        args: parsed ``argparse.Namespace`` from the ``test`` subparser.
    """
    print("Test:")
    print("text processing mode: {}".format(args.text_processing))


def main():
    """Build the argument parser, parse argv, and dispatch to a subcommand."""
    print("running...")
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()

    hparams = make_hparams()
    subparser = subparsers.add_parser('train')
    # The lambda closes over ``hparams`` so run_train receives the same
    # object whose fields were turned into flags below.
    subparser.set_defaults(callback=lambda args: run_train(args, hparams))
    hparams.populate_arguments(subparser)
    subparser.add_argument("--text-processing", default='default')

    subparser = subparsers.add_parser("test")
    subparser.set_defaults(callback=run_test)
    subparser.add_argument("--test", required=True)
    subparser.add_argument("--text-processing", default='default')

    args = parser.parse_args()
    # Subparsers are optional by default, so with no subcommand no
    # ``callback`` default is ever set; report a usage error (exit status 2)
    # instead of crashing with AttributeError.
    if not hasattr(args, "callback"):
        parser.error("a subcommand is required (choose from 'train', 'test')")
    print(args.__dict__)
    args.callback(args)


if __name__ == '__main__':
    main()