"""Register flags for optimizing performance.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import multiprocessing |
|
|
|
from absl import flags |
|
import tensorflow as tf |
|
|
|
from official.utils.flags._conventions import help_wrap |
|
|
|
|
|
|
|
DTYPE_MAP = {
    "fp16": tf.float16,
    "bf16": tf.bfloat16,
    "fp32": tf.float32,
}
|
def get_tf_dtype(flags_obj):
  """Returns the tf.dtype named by the --dtype flag."""
  if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
    # With the graph rewrite, the graph is built with fp32 ops and variables,
    # and the rewrite pass converts selected ops to fp16 afterwards, so
    # float32 is returned here regardless of --dtype.
    return tf.float32
  return DTYPE_MAP[flags_obj.dtype]


def get_loss_scale(flags_obj, default_for_fp16):
  """Returns the loss scale implied by the --loss_scale and --dtype flags."""
  dtype = get_tf_dtype(flags_obj)
  if flags_obj.loss_scale == "dynamic":
    return flags_obj.loss_scale
  elif flags_obj.loss_scale is not None:
    return float(flags_obj.loss_scale)
  elif dtype == tf.float32 or dtype == tf.bfloat16:
    # Loss scaling only guards against fp16 gradient underflow; fp32 and
    # bfloat16 (same exponent range as fp32) do not need it.
    return 1
  else:
    assert dtype == tf.float16
    return default_for_fp16
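
# A minimal usage sketch (illustrative only, not part of the module's API):
# once define_performance (below) has registered the flags and absl has
# parsed argv, a training script could resolve the numeric settings like
# this. The fp16 default of 128 mirrors the default mentioned in the
# loss_scale help text; it is not required.
#
#   flags_obj = flags.FLAGS
#   dtype = get_tf_dtype(flags_obj)  # e.g. tf.float16 for --dtype=fp16
#   loss_scale = get_loss_scale(flags_obj, default_for_fp16=128)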
|
|
def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
                       synthetic_data=False, max_train_steps=False, dtype=False,
                       all_reduce_alg=False, num_packs=False,
                       tf_gpu_thread_mode=False,
                       datasets_num_private_threads=False,
                       datasets_num_parallel_batches=False,
                       dynamic_loss_scale=False, fp16_implementation=False,
                       loss_scale=False,
                       tf_data_experimental_slack=False, enable_xla=False,
                       training_dataset_cache=False):
"""Register flags for specifying performance tuning arguments. |
|
|
|
Args: |
|
num_parallel_calls: Create a flag to specify parallelism of data loading. |
|
inter_op: Create a flag to allow specification of inter op threads. |
|
intra_op: Create a flag to allow specification of intra op threads. |
|
synthetic_data: Create a flag to allow the use of synthetic data. |
|
max_train_steps: Create a flags to allow specification of maximum number |
|
of training steps |
|
dtype: Create flags for specifying dtype. |
|
all_reduce_alg: If set forces a specific algorithm for multi-gpu. |
|
num_packs: If set provides number of packs for MirroredStrategy's cross |
|
device ops. |
|
tf_gpu_thread_mode: gpu_private triggers us of private thread pool. |
|
datasets_num_private_threads: Number of private threads for datasets. |
|
datasets_num_parallel_batches: Determines how many batches to process in |
|
parallel when using map and batch from tf.data. |
|
dynamic_loss_scale: Allow the "loss_scale" flag to take on the value |
|
"dynamic". Only valid if `dtype` is True. |
|
fp16_implementation: Create fp16_implementation flag. |
|
loss_scale: Controls the loss scaling, normally for mixed-precision |
|
training. Can only be turned on if dtype is also True. |
|
tf_data_experimental_slack: Determines whether to enable tf.data's |
|
`experimental_slack` option. |
|
enable_xla: Determines if XLA (auto clustering) is turned on. |
|
training_dataset_cache: Whether to cache the training dataset on workers. |
|
Typically used to improve training performance when training data is in |
|
remote storage and can fit into worker memory. |
|
|
|
Returns: |
|
A list of flags for core.py to marks as key flags. |
|
""" |
|
  key_flags = []
  if num_parallel_calls:
    flags.DEFINE_integer(
        name="num_parallel_calls", short_name="npc",
        default=multiprocessing.cpu_count(),
        help=help_wrap("The number of records that are processed in parallel "
                       "during input processing. This can be optimized per "
                       "data set but for generally homogeneous data sets, "
                       "should be approximately the number of available CPU "
                       "cores. (default behavior)"))

  if inter_op:
    flags.DEFINE_integer(
        name="inter_op_parallelism_threads", short_name="inter", default=0,
        help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details."))

  if intra_op:
    flags.DEFINE_integer(
        name="intra_op_parallelism_threads", short_name="intra", default=0,
        help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details."))

  if synthetic_data:
    flags.DEFINE_bool(
        name="use_synthetic_data", short_name="synth", default=False,
        help=help_wrap(
            "If set, use fake data (zeroes) instead of a real dataset. "
            "This mode is useful for performance debugging, as it removes "
            "input processing steps, but will not learn anything."))

  if max_train_steps:
    flags.DEFINE_integer(
        name="max_train_steps", short_name="mts", default=None,
        help=help_wrap(
            "The model will stop training if the global_step reaches this "
            "value. If not set, training will run until the specified number "
            "of epochs have run as usual. It is generally recommended to set "
            "--train_epochs=1 when using this flag."))

  if dtype:
    flags.DEFINE_enum(
        name="dtype", short_name="dt", default="fp32",
        enum_values=DTYPE_MAP.keys(),
        help=help_wrap("The TensorFlow datatype used for calculations. "
                       "Variables may be cast to a higher precision on a "
                       "case-by-case basis for numerical stability."))

    loss_scale_help_text = (
        "The amount to scale the loss by when the model is run. {}. Before "
        "gradients are computed, the loss is multiplied by the loss scale, "
        "making all gradients loss_scale times larger. To adjust for this, "
        "gradients are divided by the loss scale before being applied to "
        "variables. This is mathematically equivalent to training without "
        "a loss scale, but the loss scale helps avoid some intermediate "
        "gradients from underflowing to zero. If not provided, the default "
        "for fp16 is 128 and 1 for all other dtypes.{}")
    if dynamic_loss_scale:
      loss_scale_help_text = loss_scale_help_text.format(
          "This can be an int/float or the string 'dynamic'",
          " The string 'dynamic' can be used to dynamically determine the "
          "optimal loss scale during training, but currently this "
          "significantly slows down performance")
      loss_scale_validation_msg = ("loss_scale should be a positive int/float "
                                   "or the string 'dynamic'.")
    else:
      loss_scale_help_text = loss_scale_help_text.format(
          "This must be an int/float", "")
      loss_scale_validation_msg = "loss_scale should be a positive int/float."
    if loss_scale:
      flags.DEFINE_string(
          name="loss_scale", short_name="ls", default=None,
          help=help_wrap(loss_scale_help_text))

      @flags.validator(flag_name="loss_scale",
                       message=loss_scale_validation_msg)
      def _check_loss_scale(loss_scale):
        """Validator to check the loss scale flag is valid."""
        if loss_scale is None:
          return True  # The None case is handled in get_loss_scale().

        if loss_scale == "dynamic" and dynamic_loss_scale:
          return True

        try:
          loss_scale = float(loss_scale)
        except ValueError:
          return False

        return loss_scale > 0
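
      # For illustration (not in the original file): with dynamic_loss_scale
      # enabled, "--loss_scale=dynamic" and "--loss_scale=128" both pass this
      # validator, while "--loss_scale=0" and "--loss_scale=abc" are rejected
      # and absl reports loss_scale_validation_msg to the user.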
|
    if fp16_implementation:
      flags.DEFINE_enum(
          name="fp16_implementation", default="keras",
          enum_values=("keras", "graph_rewrite"),
          help=help_wrap(
              "When --dtype=fp16, how fp16 should be implemented. This has no "
              "impact on correctness. 'keras' uses the "
              "tf.keras.mixed_precision API. 'graph_rewrite' uses the "
              "tf.train.experimental.enable_mixed_precision_graph_rewrite "
              "API."))

      @flags.multi_flags_validator(["fp16_implementation", "dtype",
                                    "loss_scale"])
      def _check_fp16_implementation(flags_dict):
        """Validator to check fp16_implementation flag is valid."""
        if (flags_dict["fp16_implementation"] == "graph_rewrite" and
            flags_dict["dtype"] != "fp16"):
          raise flags.ValidationError("--fp16_implementation should not be "
                                      "specified unless --dtype=fp16")
        return True

  if all_reduce_alg:
    flags.DEFINE_string(
        name="all_reduce_alg", short_name="ara", default=None,
        help=help_wrap("Defines the algorithm to use for performing "
                       "all-reduce. When specified with MirroredStrategy for "
                       "single worker, this controls "
                       "tf.contrib.distribute.AllReduceCrossTowerOps. When "
                       "specified with MultiWorkerMirroredStrategy, this "
                       "controls "
                       "tf.distribute.experimental.CollectiveCommunication; "
                       "valid options are `ring` and `nccl`."))

  if num_packs:
    flags.DEFINE_integer(
        name="num_packs", default=1,
        help=help_wrap("Sets `num_packs` in the cross device ops used in "
                       "MirroredStrategy. For details, see "
                       "tf.distribute.NcclAllReduce."))

  if tf_gpu_thread_mode:
    flags.DEFINE_string(
        name="tf_gpu_thread_mode", short_name="gt_mode", default=None,
        help=help_wrap(
            "Whether and how the GPU device uses its own threadpool."))

    flags.DEFINE_integer(
        name="per_gpu_thread_count", short_name="pgtc", default=0,
        help=help_wrap(
            "The number of threads to use for GPU. Only valid when "
            "tf_gpu_thread_mode is not global."))

  if datasets_num_private_threads:
    flags.DEFINE_integer(
        name="datasets_num_private_threads",
        default=None,
        help=help_wrap(
            "Number of threads for a private threadpool created for all "
            "datasets computation."))

  if datasets_num_parallel_batches:
    flags.DEFINE_integer(
        name="datasets_num_parallel_batches",
        default=None,
        help=help_wrap(
            "Determines how many batches to process in parallel when using "
            "map and batch from tf.data."))

  if training_dataset_cache:
    flags.DEFINE_boolean(
        name="training_dataset_cache",
        default=False,
        help=help_wrap(
            "Determines whether to cache the training dataset on workers. "
            "Typically used to improve training performance when training "
            "data is in remote storage and can fit into worker memory."))

  if tf_data_experimental_slack:
    flags.DEFINE_boolean(
        name="tf_data_experimental_slack",
        default=False,
        help=help_wrap(
            "Whether to enable tf.data's `experimental_slack` option."))

  if enable_xla:
    flags.DEFINE_boolean(
        name="enable_xla", default=False,
        help="Whether to enable XLA auto jit compilation")

  return key_flags
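
# A minimal end-to-end sketch (illustrative only; `main` and the absl.app
# import are assumptions, not part of this module): a script registers the
# flags once at import time and reads them after absl has parsed argv.
#
#   from absl import app
#
#   define_performance(dtype=True, loss_scale=True, dynamic_loss_scale=True)
#
#   def main(_):
#     dtype = get_tf_dtype(flags.FLAGS)
#     loss_scale = get_loss_scale(flags.FLAGS, default_for_fp16=128)
#
#   app.run(main)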
|
|