"""Misc for Transformer.""" |
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
import tensorflow as tf

from official.nlp.transformer import model_params
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils

FLAGS = flags.FLAGS

PARAMS_MAP = {
    'tiny': model_params.TINY_PARAMS,
    'base': model_params.BASE_PARAMS,
    'big': model_params.BIG_PARAMS,
}
|
|
def get_model_params(param_set, num_gpus):
  """Gets predefined model params."""
  if num_gpus > 1:
    if param_set == 'big':
      return model_params.BIG_MULTI_GPU_PARAMS.copy()
    elif param_set == 'base':
      return model_params.BASE_MULTI_GPU_PARAMS.copy()
    else:
      raise ValueError('Invalid param_set for multi-GPU training: '
                       'param_set={} num_gpus={}'.format(param_set, num_gpus))

  return PARAMS_MAP[param_set].copy()
|
|
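# Example usage (a minimal sketch; the argument values are illustrative):
#
#   params = get_model_params('base', num_gpus=2)  # -> BASE_MULTI_GPU_PARAMS
#   params['batch_size'] = 4096  # safe to mutate: the function returns a copy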
|
|
def define_transformer_flags():
  """Adds flags and flag validators for running transformer_main."""
  flags_core.define_base(num_gpu=True, distribution_strategy=True)
  flags_core.define_performance(
      num_parallel_calls=True,
      inter_op=False,
      intra_op=False,
      synthetic_data=True,
      max_train_steps=False,
      dtype=True,
      loss_scale=True,
      all_reduce_alg=True,
      num_packs=True,
      tf_gpu_thread_mode=True,
      datasets_num_private_threads=True,
      enable_xla=True,
      fp16_implementation=True
  )

  flags_core.define_benchmark()
  flags_core.define_device(tpu=True)

  flags.DEFINE_integer(
      name='train_steps', short_name='ts', default=300000,
      help=flags_core.help_wrap('The number of steps used to train.'))
  flags.DEFINE_integer(
      name='steps_between_evals', short_name='sbe', default=5000,
      help=flags_core.help_wrap(
          'The number of training steps to run between evaluations. This is '
          'used if --train_steps is defined.'))
  flags.DEFINE_boolean(
      name='enable_time_history', default=True,
      help='Whether to enable TimeHistory callback.')
  flags.DEFINE_boolean(
      name='enable_tensorboard', default=False,
      help='Whether to enable TensorBoard callback.')
  flags.DEFINE_boolean(
      name='enable_metrics_in_training', default=False,
      help='Whether to enable metrics during training.')
  flags.DEFINE_boolean(
      name='enable_mlir_bridge',
      default=False,
      help='Whether to enable the MLIR-based TF-to-XLA bridge.')

  flags.adopt_module_key_flags(flags_core)

  flags.DEFINE_enum(
      name='param_set', short_name='mp', default='big',
      enum_values=PARAMS_MAP.keys(),
      help=flags_core.help_wrap(
          'Parameter set to use when creating and training the model. The '
          'parameters define the input shape (batch size and max length), '
          'model configuration (size of embedding, # of hidden layers, etc.), '
          'and various other settings. The big parameter set increases the '
          'default batch size, embedding/hidden size, and filter size. For a '
          'complete list of parameters, please see model/model_params.py.'))

  flags.DEFINE_bool(
      name='static_batch', short_name='sb', default=False,
      help=flags_core.help_wrap(
          'Whether the batches in the dataset should have static shapes. In '
          'general, this setting should be False. Dynamic shapes allow the '
          'inputs to be grouped so that the number of padding tokens is '
          'minimized, which helps model training. In cases where the input '
          'shape must be static (e.g. running on TPU), this setting is '
          'ignored and static batching is always used.'))
  flags.DEFINE_integer(
      name='max_length', short_name='ml', default=256,
      help=flags_core.help_wrap(
          'Max sentence length for Transformer. Default is 256. Note: it is '
          'usually more effective to use a smaller max length if '
          'static_batch is enabled, e.g. 64.'))

  flags.DEFINE_integer(
      name='validation_steps', short_name='vs', default=64,
      help=flags_core.help_wrap('The number of steps used in validation.'))

  flags.DEFINE_string(
      name='bleu_source', short_name='bls', default=None,
      help=flags_core.help_wrap(
          'Path to source file containing text to translate when calculating '
          'the official BLEU score. Both --bleu_source and --bleu_ref must '
          'be set.'))
  flags.DEFINE_string(
      name='bleu_ref', short_name='blr', default=None,
      help=flags_core.help_wrap(
          'Path to file containing reference translations for calculating '
          'the official BLEU score. Both --bleu_source and --bleu_ref must '
          'be set.'))
  flags.DEFINE_string(
      name='vocab_file', short_name='vf', default=None,
      help=flags_core.help_wrap(
          'Path to subtoken vocabulary file. If data_download.py was used to '
          'download and encode the training data, look in the data_dir to '
          'find the vocab file.'))
  flags.DEFINE_string(
      name='mode', default='train',
      help=flags_core.help_wrap('Mode: train, eval, or predict.'))
  flags.DEFINE_bool(
      name='use_ctl',
      default=False,
      help=flags_core.help_wrap(
          'Whether the model runs with a custom training loop.'))
  flags.DEFINE_integer(
      name='decode_batch_size',
      default=32,
      help=flags_core.help_wrap(
          'Global batch size used for Transformer autoregressive decoding on '
          'TPU.'))
  flags.DEFINE_integer(
      name='decode_max_length',
      default=97,
      help=flags_core.help_wrap(
          'Max sequence length of the decode/eval data. This is used by '
          'Transformer autoregressive decoding on TPU to minimize padding.'))
  flags.DEFINE_bool(
      name='padded_decode',
      default=False,
      help=flags_core.help_wrap(
          'Whether the autoregressive decoding runs with input data padded '
          'to the decode_max_length. For TPU/XLA-GPU runs, this flag has to '
          'be set due to the static shape requirement. Although CPU/GPU '
          'could also use padded_decode, it has not been tested. In '
          'addition, this method introduces unnecessary overhead, which '
          'grows quadratically with the max sequence length.'))
  flags.DEFINE_bool(
      name='enable_checkpointing',
      default=True,
      help=flags_core.help_wrap(
          'Whether to do checkpointing during training. When running under '
          'a benchmark harness, we avoid checkpointing.'))

  flags_core.set_defaults(data_dir='/tmp/translate_ende',
                          model_dir='/tmp/transformer_model',
                          batch_size=None)

  # The validators are registered here, inside define_transformer_flags(), so
  # that the flags they reference exist by the time the decorators run.
  @flags.multi_flags_validator(
      ['bleu_source', 'bleu_ref'],
      message='Both or neither --bleu_source and --bleu_ref must be defined.')
  def _check_bleu_files(flags_dict):
    return (flags_dict['bleu_source'] is None) == (
        flags_dict['bleu_ref'] is None)

  @flags.multi_flags_validator(
      ['bleu_source', 'bleu_ref', 'vocab_file'],
      message='--vocab_file must be defined if --bleu_source and --bleu_ref '
              'are defined.')
  def _check_bleu_vocab_file(flags_dict):
    if flags_dict['bleu_source'] and flags_dict['bleu_ref']:
      return flags_dict['vocab_file'] is not None
    return True
|
|
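# Example wiring (a minimal sketch; `main` is a hypothetical driver along the
# lines of transformer_main.py in this directory):
#
#   from absl import app, flags
#
#   define_transformer_flags()
#
#   def main(_):
#     params = get_model_params(flags.FLAGS.param_set, flags.FLAGS.num_gpus)
#     ...
#
#   if __name__ == '__main__':
#     app.run(main)  # passing only one of --bleu_source/--bleu_ref fails
#                    # flag validation, per the validators above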
|
|
def get_callbacks():
  """Returns common callbacks."""
  callbacks = []
  if FLAGS.enable_time_history:
    time_callback = keras_utils.TimeHistory(
        FLAGS.batch_size,
        FLAGS.log_steps,
        logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
    callbacks.append(time_callback)

  if FLAGS.enable_tensorboard:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=FLAGS.model_dir)
    callbacks.append(tensorboard_callback)

  return callbacks
|
|
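# Example usage (a minimal sketch; `model` and `train_ds` are assumed to
# exist and the relevant FLAGS to have been parsed):
#
#   callbacks = get_callbacks()
#   history = model.fit(train_ds, epochs=1, steps_per_epoch=100,
#                       callbacks=callbacks)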
|
|
def update_stats(history, stats, callbacks):
  """Normalizes and updates dictionary of stats.

  Args:
    history: Results of training, i.e. the History object returned by
      keras Model.fit.
    stats: Dict with pre-existing training stats.
    callbacks: A list of callbacks which might include a time history callback
      used during keras.fit.
  """
  if history and history.history:
    train_hist = history.history
    # Record the final training loss.
    stats['loss'] = float(train_hist['loss'][-1])

  if not callbacks:
    return

  # Look for the TimeHistory callback and pull throughput stats from it.
  for callback in callbacks:
    if isinstance(callback, keras_utils.TimeHistory):
      timestamp_log = callback.timestamp_log
      stats['step_timestamp_log'] = timestamp_log
      stats['train_finish_time'] = callback.train_finish_time
      if len(timestamp_log) > 1:
        # Average examples/sec over the logged interval: examples per log
        # entry times the number of intervals, divided by elapsed time.
        stats['avg_exp_per_second'] = (
            callback.batch_size * callback.log_steps *
            (len(callback.timestamp_log) - 1) /
            (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
|
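# Example usage (a minimal sketch, continuing the `model.fit` sketch above):
#
#   stats = {}
#   update_stats(history, stats, callbacks)
#   # stats now holds 'loss' and, if a TimeHistory callback logged at least
#   # two entries, 'step_timestamp_log', 'train_finish_time', and
#   # 'avg_exp_per_second'.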
|