Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
# Copyright 2017-present, Facebook, Inc. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
"""Model architecture/optimization options for DrQA document reader.""" | |
import argparse | |
import logging | |
logger = logging.getLogger(__name__) | |
# Index of arguments concerning the core model architecture | |
MODEL_ARCHITECTURE = { | |
'model_type', 'embedding_dim', 'hidden_size', 'doc_layers', | |
'question_layers', 'rnn_type', 'concat_rnn_layers', 'question_merge', | |
'use_qemb', 'use_in_question', 'use_pos', 'use_ner', 'use_lemma', 'use_tf' | |
} | |
# Index of arguments concerning the model optimizer/training | |
MODEL_OPTIMIZER = { | |
'fix_embeddings', 'optimizer', 'learning_rate', 'momentum', 'weight_decay', | |
'rnn_padding', 'dropout_rnn', 'dropout_rnn_output', 'dropout_emb', | |
'max_len', 'grad_clipping', 'tune_partial' | |
} | |
def str2bool(v): | |
return v.lower() in ('yes', 'true', 't', '1', 'y') | |
def add_model_args(parser): | |
parser.register('type', 'bool', str2bool) | |
# Model architecture | |
model = parser.add_argument_group('DrQA Reader Model Architecture') | |
model.add_argument('--model-type', type=str, default='rnn', | |
help='Model architecture type') | |
model.add_argument('--embedding-dim', type=int, default=300, | |
help='Embedding size if embedding_file is not given') | |
model.add_argument('--hidden-size', type=int, default=128, | |
help='Hidden size of RNN units') | |
model.add_argument('--doc-layers', type=int, default=3, | |
help='Number of encoding layers for document') | |
model.add_argument('--question-layers', type=int, default=3, | |
help='Number of encoding layers for question') | |
model.add_argument('--rnn-type', type=str, default='lstm', | |
help='RNN type: LSTM, GRU, or RNN') | |
# Model specific details | |
detail = parser.add_argument_group('DrQA Reader Model Details') | |
detail.add_argument('--concat-rnn-layers', type='bool', default=True, | |
help='Combine hidden states from each encoding layer') | |
detail.add_argument('--question-merge', type=str, default='self_attn', | |
help='The way of computing the question representation') | |
detail.add_argument('--use-qemb', type='bool', default=True, | |
help='Whether to use weighted question embeddings') | |
detail.add_argument('--use-in-question', type='bool', default=True, | |
help='Whether to use in_question_* features') | |
detail.add_argument('--use-pos', type='bool', default=True, | |
help='Whether to use pos features') | |
detail.add_argument('--use-ner', type='bool', default=True, | |
help='Whether to use ner features') | |
detail.add_argument('--use-lemma', type='bool', default=True, | |
help='Whether to use lemma features') | |
detail.add_argument('--use-tf', type='bool', default=True, | |
help='Whether to use term frequency features') | |
# Optimization details | |
optim = parser.add_argument_group('DrQA Reader Optimization') | |
optim.add_argument('--dropout-emb', type=float, default=0.4, | |
help='Dropout rate for word embeddings') | |
optim.add_argument('--dropout-rnn', type=float, default=0.4, | |
help='Dropout rate for RNN states') | |
optim.add_argument('--dropout-rnn-output', type='bool', default=True, | |
help='Whether to dropout the RNN output') | |
optim.add_argument('--optimizer', type=str, default='adamax', | |
help='Optimizer: sgd or adamax') | |
optim.add_argument('--learning-rate', type=float, default=0.1, | |
help='Learning rate for SGD only') | |
optim.add_argument('--grad-clipping', type=float, default=10, | |
help='Gradient clipping') | |
optim.add_argument('--weight-decay', type=float, default=0, | |
help='Weight decay factor') | |
optim.add_argument('--momentum', type=float, default=0, | |
help='Momentum factor') | |
optim.add_argument('--fix-embeddings', type='bool', default=True, | |
help='Keep word embeddings fixed (use pretrained)') | |
optim.add_argument('--tune-partial', type=int, default=0, | |
help='Backprop through only the top N question words') | |
optim.add_argument('--rnn-padding', type='bool', default=False, | |
help='Explicitly account for padding in RNN encoding') | |
optim.add_argument('--max-len', type=int, default=15, | |
help='The max span allowed during decoding') | |
def get_model_args(args): | |
"""Filter args for model ones. | |
From a args Namespace, return a new Namespace with *only* the args specific | |
to the model architecture or optimization. (i.e. the ones defined here.) | |
""" | |
global MODEL_ARCHITECTURE, MODEL_OPTIMIZER | |
required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER | |
arg_values = {k: v for k, v in vars(args).items() if k in required_args} | |
return argparse.Namespace(**arg_values) | |
def override_model_args(old_args, new_args): | |
"""Set args to new parameters. | |
Decide which model args to keep and which to override when resolving a set | |
of saved args and new args. | |
We keep the new optimation, but leave the model architecture alone. | |
""" | |
global MODEL_OPTIMIZER | |
old_args, new_args = vars(old_args), vars(new_args) | |
for k in old_args.keys(): | |
if k in new_args and old_args[k] != new_args[k]: | |
if k in MODEL_OPTIMIZER: | |
logger.info('Overriding saved %s: %s --> %s' % | |
(k, old_args[k], new_args[k])) | |
old_args[k] = new_args[k] | |
else: | |
logger.info('Keeping saved %s: %s' % (k, old_args[k])) | |
return argparse.Namespace(**old_args) | |