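"""Entry-point stub for training and evaluating a parser built on benepar.

Collects the default hyperparameters in an nkutil.HParams object and wires up
an argparse CLI with `train` and `test` subcommands. populate_arguments()
mirrors each hyperparameter as a command-line flag, so any default below can
be overridden at invocation time.
"""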
import argparse
from benepar import nkutil


def make_hparams():
    return nkutil.HParams(
        # Data processing
        max_len_train=0,  # no length limit
        max_len_dev=0,  # no length limit
        # Optimization
        batch_size=32,
        learning_rate=0.00005,
        learning_rate_warmup_steps=160,
        clip_grad_norm=0.0,  # no clipping
        checks_per_epoch=4,
        step_decay_factor=0.5,
        step_decay_patience=5,
        max_consecutive_decays=3,  # establishes a termination criterion
        # CharLSTM
        use_chars_lstm=False,
        d_char_emb=64,
        char_lstm_input_dropout=0.2,
        # BERT and other pre-trained models
        use_pretrained=False,
        pretrained_model="bert-base-uncased",
        # Partitioned transformer encoder
        use_encoder=False,
        d_model=1024,
        num_layers=8,
        num_heads=8,
        d_kv=64,
        d_ff=2048,
        encoder_max_len=512,
        # Dropout
        morpho_emb_dropout=0.2,
        attention_dropout=0.2,
        relu_dropout=0.1,
        residual_dropout=0.2,
        # Output heads and losses
        force_root_constituent="auto",
        predict_tags=False,
        d_label_hidden=256,
        d_tag_hidden=256,
        tag_loss_scale=5.0,
    )


def run_train(args, hparams):
    # Apply any hyperparameter overrides passed on the command line;
    # set_from_args is the counterpart to populate_arguments on benepar's
    # nkutil.HParams.
    hparams.set_from_args(args)
    print("Train:")
    print("dimension of attention key/value: {}".format(hparams.d_kv))
    print("use pretrained: {}".format(args.use_pretrained))
    print("text processing mode: {}".format(args.text_processing))


def run_test(args):
    print("Test:")
    print("text processing mode: {}".format(args.text_processing))


def main():
    print("running...")
    parser = argparse.ArgumentParser()
    # Require a subcommand so args.callback is always set; without this,
    # running the script with no arguments raises an AttributeError below.
    subparsers = parser.add_subparsers(dest="mode", required=True)

    hparams = make_hparams()

    subparser = subparsers.add_parser("train")
    subparser.set_defaults(callback=lambda args: run_train(args, hparams))
    # Expose every hyperparameter from make_hparams() as a command-line flag.
    hparams.populate_arguments(subparser)
    subparser.add_argument("--text-processing", default="default")

    subparser = subparsers.add_parser("test")
    subparser.set_defaults(callback=run_test)
    subparser.add_argument("--test", required=True)
    subparser.add_argument("--text-processing", default="default")

    args = parser.parse_args()
    print(args.__dict__)
    args.callback(args)


if __name__ == "__main__":
    main()
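
# Example invocations (assuming this file is saved as main.py; the --test
# path below is only a placeholder):
#   python main.py train --text-processing default
#   python main.py test --test data/dev.txt --text-processing default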