# Spaces: Sleeping (status text left over from the hosting page; not part of the program)
import argparse

from benepar import nkutil
def make_hparams():
    """Build a fresh hyperparameter object holding the default settings.

    The defaults are grouped by concern below and then passed, in order, as
    keyword arguments to ``nkutil.HParams``. ``main`` hands the result to
    ``populate_arguments`` on the ``train`` subparser, which presumably
    exposes each entry as a command-line option (TODO confirm against
    ``nkutil``).

    Returns:
        nkutil.HParams: a new object initialised with the defaults below.
    """
    # Data processing
    data_processing = dict(
        max_len_train=0,  # no length limit
        max_len_dev=0,  # no length limit
    )
    # Optimization
    optimization = dict(
        batch_size=32,
        learning_rate=0.00005,
        learning_rate_warmup_steps=160,
        clip_grad_norm=0.0,  # no clipping
        checks_per_epoch=4,
        step_decay_factor=0.5,
        step_decay_patience=5,
        max_consecutive_decays=3,  # establishes a termination criterion
    )
    # CharLSTM
    char_lstm = dict(
        use_chars_lstm=False,
        d_char_emb=64,
        char_lstm_input_dropout=0.2,
    )
    # BERT and other pre-trained models
    pretrained = dict(
        use_pretrained=False,
        pretrained_model="bert-base-uncased",
    )
    # Partitioned transformer encoder
    encoder = dict(
        use_encoder=False,
        d_model=1024,
        num_layers=8,
        num_heads=8,
        d_kv=64,
        d_ff=2048,
        encoder_max_len=512,
    )
    # Dropout
    dropout = dict(
        morpho_emb_dropout=0.2,
        attention_dropout=0.2,
        relu_dropout=0.1,
        residual_dropout=0.2,
    )
    # Output heads and losses
    output_heads = dict(
        force_root_constituent="auto",
        predict_tags=False,
        d_label_hidden=256,
        d_tag_hidden=256,
        tag_loss_scale=5.0,
    )
    # Keyword-argument order is preserved when the groups are expanded, so
    # HParams receives the entries in the same order as a single literal call.
    return nkutil.HParams(
        **data_processing,
        **optimization,
        **char_lstm,
        **pretrained,
        **encoder,
        **dropout,
        **output_heads,
    )
def run_train(args, hparams):
    """Print the configuration used for a training run.

    Args:
        args: parsed ``argparse.Namespace`` for the ``train`` subcommand;
            ``use_pretrained`` and ``text_processing`` are read from it, and
            ``d_kv`` too when present.
        hparams: hyperparameter object from ``make_hparams``; its ``d_kv`` is
            only used as a fallback when ``args`` carries no ``d_kv``.
    """
    print("Train:")
    # ``hparams`` still holds the defaults: main() never copies parsed
    # command-line values back into it. Prefer the parsed value (presumably
    # added to ``args`` by ``populate_arguments``, like ``use_pretrained``)
    # so a d_kv override is reported correctly.
    d_kv = getattr(args, "d_kv", hparams.d_kv)
    print("dimension of attention key value: {}".format(d_kv))
    print("use pretrained: {}".format(args.use_pretrained))
    print("text processing mode: {}".format(args.text_processing))
def run_test(args):
    """Print the configuration used for an evaluation run.

    Args:
        args: parsed ``argparse.Namespace`` for the ``test`` subcommand;
            only ``text_processing`` is read here.
    """
    header = "Test:"
    print(header)
    mode = args.text_processing
    print("text processing mode: {}".format(mode))
def main(argv=None):
    """Parse the command line and dispatch to the selected subcommand.

    Subcommands:
        train: options come from ``hparams.populate_arguments`` plus
            ``--text-processing``; dispatches to ``run_train``.
        test: requires ``--test`` and accepts ``--text-processing``;
            dispatches to ``run_test``.

    Args:
        argv: list of arguments to parse instead of ``sys.argv[1:]``;
            ``None`` (the default) keeps the standard behaviour.
    """
    print("running...")
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()

    hparams = make_hparams()
    train_parser = subparsers.add_parser('train')
    # Bind the one hparams object so run_train receives it with the args.
    train_parser.set_defaults(callback=lambda args: run_train(args, hparams))
    hparams.populate_arguments(train_parser)
    train_parser.add_argument("--text-processing", default='default')

    test_parser = subparsers.add_parser("test")
    test_parser.set_defaults(callback=run_test)
    test_parser.add_argument("--test", required=True)
    test_parser.add_argument("--text-processing", default='default')

    args = parser.parse_args(argv)
    # Subparsers are optional by default, so invoking the script without a
    # subcommand leaves no ``callback`` attribute. Report a usage error
    # (exit status 2) instead of crashing with AttributeError below.
    if not hasattr(args, "callback"):
        parser.error("a subcommand is required (train or test)")
    print(args.__dict__)
    args.callback(args)
if __name__ == "__main__":
    # Launch the CLI only when executed as a script, never on import.
    main()