NCTCMumbai's picture
Upload 2583 files
97b6013 verified
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flags for optimizing performance."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
from absl import flags # pylint: disable=g-bad-import-order
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags._conventions import help_wrap
# Map the string spellings of a dtype to the corresponding TensorFlow dtype.
# The keys double as the allowed values of the `dtype` enum flag registered in
# define_performance(); get_tf_dtype() performs the lookup.
DTYPE_MAP = {
"fp16": tf.float16,
"bf16": tf.bfloat16,
"fp32": tf.float32,
}
def get_tf_dtype(flags_obj):
  """Return the TensorFlow dtype the graph should be built with.

  Args:
    flags_obj: Object holding the parsed flags. Its `dtype` attribute is read
      unless `fp16_implementation` is "graph_rewrite"; the
      `fp16_implementation` attribute is optional.

  Returns:
    `tf.float32` when `fp16_implementation` is "graph_rewrite", otherwise the
    `DTYPE_MAP` entry for `flags_obj.dtype`.
  """
  implementation = getattr(flags_obj, "fp16_implementation", None)
  if implementation != "graph_rewrite":
    return DTYPE_MAP[flags_obj.dtype]
  # Under the graph rewrite the graph is constructed in fp32 and the rewrite
  # itself is what switches ops over to fp16.
  return tf.float32
def get_loss_scale(flags_obj, default_for_fp16):
  """Work out the loss scale to use from the parsed flags.

  Args:
    flags_obj: Object holding the parsed flags; `loss_scale` is read here and
      the dtype is resolved through get_tf_dtype().
    default_for_fp16: Value returned when no loss scale was given and the
      resolved dtype is `tf.float16`.

  Returns:
    The string "dynamic" if that was requested, the flag value converted to
    `float` if one was provided, `1` for float32/bfloat16, and
    `default_for_fp16` otherwise.
  """
  # Resolve the dtype up front, exactly as before, so a bad dtype flag fails
  # here regardless of which branch below is taken.
  dtype = get_tf_dtype(flags_obj)
  requested = flags_obj.loss_scale

  if requested == "dynamic":
    return requested
  if requested is not None:
    return float(requested)

  if dtype in (tf.float32, tf.bfloat16):
    # Neither fp32 nor bf16 gets a loss scale by default.
    return 1

  assert dtype == tf.float16
  return default_for_fp16
def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
                       synthetic_data=False, max_train_steps=False, dtype=False,
                       all_reduce_alg=False, num_packs=False,
                       tf_gpu_thread_mode=False,
                       datasets_num_private_threads=False,
                       datasets_num_parallel_batches=False,
                       dynamic_loss_scale=False, fp16_implementation=False,
                       loss_scale=False,
                       tf_data_experimental_slack=False, enable_xla=False,
                       training_dataset_cache=False):
  """Register flags for specifying performance tuning arguments.

  Each argument is a boolean switch: when true, the corresponding absl flag
  (and, where applicable, its validator) is registered.

  Args:
    num_parallel_calls: Create a flag to specify parallelism of data loading.
    inter_op: Create a flag to allow specification of inter op threads.
    intra_op: Create a flag to allow specification of intra op threads.
    synthetic_data: Create a flag to allow the use of synthetic data.
    max_train_steps: Create a flag to allow specification of maximum number
      of training steps.
    dtype: Create flags for specifying dtype.
    all_reduce_alg: If set forces a specific algorithm for multi-gpu.
    num_packs: If set provides number of packs for MirroredStrategy's cross
      device ops.
    tf_gpu_thread_mode: gpu_private triggers use of private thread pool.
    datasets_num_private_threads: Number of private threads for datasets.
    datasets_num_parallel_batches: Determines how many batches to process in
      parallel when using map and batch from tf.data.
    dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
      "dynamic". Only valid if `dtype` is True.
    fp16_implementation: Create fp16_implementation flag. Only takes effect
      if `dtype` is True.
    loss_scale: Controls the loss scaling, normally for mixed-precision
      training. Can only be turned on if dtype is also True.
    tf_data_experimental_slack: Determines whether to enable tf.data's
      `experimental_slack` option.
    enable_xla: Determines if XLA (auto clustering) is turned on.
    training_dataset_cache: Whether to cache the training dataset on workers.
      Typically used to improve training performance when training data is in
      remote storage and can fit into worker memory.

  Returns:
    A list of flags for core.py to mark as key flags. Nothing is ever added
    to it in this function, so the list is currently always empty.
  """

  key_flags = []
  if num_parallel_calls:
    flags.DEFINE_integer(
        name="num_parallel_calls", short_name="npc",
        default=multiprocessing.cpu_count(),
        help=help_wrap("The number of records that are processed in parallel "
                       "during input processing. This can be optimized per "
                       "data set but for generally homogeneous data sets, "
                       "should be approximately the number of available CPU "
                       "cores. (default behavior)"))

  if inter_op:
    flags.DEFINE_integer(
        name="inter_op_parallelism_threads", short_name="inter", default=0,
        help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details.")
    )

  if intra_op:
    flags.DEFINE_integer(
        name="intra_op_parallelism_threads", short_name="intra", default=0,
        help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details."))

  if synthetic_data:
    flags.DEFINE_bool(
        name="use_synthetic_data", short_name="synth", default=False,
        help=help_wrap(
            "If set, use fake data (zeroes) instead of a real dataset. "
            "This mode is useful for performance debugging, as it removes "
            "input processing steps, but will not learn anything."))

  if max_train_steps:
    flags.DEFINE_integer(
        name="max_train_steps", short_name="mts", default=None, help=help_wrap(
            "The model will stop training if the global_step reaches this "
            "value. If not set, training will run until the specified number "
            "of epochs have run as usual. It is generally recommended to set "
            "--train_epochs=1 when using this flag."
        ))

  if dtype:
    flags.DEFINE_enum(
        name="dtype", short_name="dt", default="fp32",
        # Hand the enum a concrete list rather than a live dict view.
        enum_values=list(DTYPE_MAP),
        help=help_wrap("The TensorFlow datatype used for calculations. "
                       "Variables may be cast to a higher precision on a "
                       "case-by-case basis for numerical stability."))

    # The two `{}` slots are filled in below depending on whether the
    # "dynamic" value is permitted.
    loss_scale_help_text = (
        "The amount to scale the loss by when the model is run. {}. Before "
        "gradients are computed, the loss is multiplied by the loss scale, "
        "making all gradients loss_scale times larger. To adjust for this, "
        "gradients are divided by the loss scale before being applied to "
        "variables. This is mathematically equivalent to training without "
        "a loss scale, but the loss scale helps avoid some intermediate "
        "gradients from underflowing to zero. If not provided the default "
        "for fp16 is 128 and 1 for all other dtypes.{}"
    )
    if dynamic_loss_scale:
      loss_scale_help_text = loss_scale_help_text.format(
          "This can be an int/float or the string 'dynamic'",
          " The string 'dynamic' can be used to dynamically determine the "
          "optimal loss scale during training, but currently this "
          "significantly slows down performance")
      loss_scale_validation_msg = ("loss_scale should be a positive int/float "
                                   "or the string 'dynamic'.")
    else:
      loss_scale_help_text = loss_scale_help_text.format(
          "This must be an int/float", "")
      loss_scale_validation_msg = "loss_scale should be a positive int/float."

    if loss_scale:
      # Registered as a string flag so that it can carry either a number or
      # the literal "dynamic"; the validator below enforces the format.
      flags.DEFINE_string(
          name="loss_scale", short_name="ls", default=None,
          help=help_wrap(loss_scale_help_text))

      @flags.validator(flag_name="loss_scale",
                       message=loss_scale_validation_msg)
      def _check_loss_scale(value):  # pylint: disable=unused-variable
        """Validator to check the loss scale flag is valid."""
        if value is None:
          return True  # null case is handled in get_loss_scale()

        if value == "dynamic" and dynamic_loss_scale:
          return True

        try:
          value = float(value)
        except ValueError:
          return False

        return value > 0

    if fp16_implementation:
      flags.DEFINE_enum(
          name="fp16_implementation", default="keras",
          # Must be a real two-element tuple. The previous mismatched quotes
          # collapsed this into the single string "keras', 'graph_rewrite".
          enum_values=("keras", "graph_rewrite"),
          help=help_wrap(
              "When --dtype=fp16, how fp16 should be implemented. This has no "
              "impact on correctness. 'keras' uses the "
              "tf.keras.mixed_precision API. 'graph_rewrite' uses the "
              "tf.train.experimental.enable_mixed_precision_graph_rewrite "
              "API."))

      # Only the flags the check actually reads are listed. Depending on
      # "loss_scale" as well would require that flag to be registered even
      # when the caller passed loss_scale=False.
      @flags.multi_flags_validator(["fp16_implementation", "dtype"])
      def _check_fp16_implementation(flags_dict):  # pylint: disable=unused-variable
        """Validator to check fp16_implementation flag is valid."""
        if (flags_dict["fp16_implementation"] == "graph_rewrite" and
            flags_dict["dtype"] != "fp16"):
          raise flags.ValidationError("--fp16_implementation should not be "
                                      "specified unless --dtype=fp16")
        return True

  if all_reduce_alg:
    flags.DEFINE_string(
        name="all_reduce_alg", short_name="ara", default=None,
        help=help_wrap("Defines the algorithm to use for performing all-reduce. "
                       "When specified with MirroredStrategy for single "
                       "worker, this controls "
                       "tf.contrib.distribute.AllReduceCrossTowerOps. When "
                       "specified with MultiWorkerMirroredStrategy, this "
                       "controls "
                       "tf.distribute.experimental.CollectiveCommunication; "
                       "valid options are `ring` and `nccl`."))

  if num_packs:
    flags.DEFINE_integer(
        name="num_packs", default=1,
        help=help_wrap("Sets `num_packs` in the cross device ops used in "
                       "MirroredStrategy. For details, see "
                       "tf.distribute.NcclAllReduce."))

  if tf_gpu_thread_mode:
    # This switch registers two flags: the mode and the per-GPU thread count.
    flags.DEFINE_string(
        name="tf_gpu_thread_mode", short_name="gt_mode", default=None,
        help=help_wrap(
            "Whether and how the GPU device uses its own threadpool.")
    )

    flags.DEFINE_integer(
        name="per_gpu_thread_count", short_name="pgtc", default=0,
        help=help_wrap(
            "The number of threads to use for GPU. Only valid when "
            "tf_gpu_thread_mode is not global.")
    )

  if datasets_num_private_threads:
    flags.DEFINE_integer(
        name="datasets_num_private_threads",
        default=None,
        help=help_wrap(
            "Number of threads for a private threadpool created for all "
            "datasets computation..")
    )

  if datasets_num_parallel_batches:
    flags.DEFINE_integer(
        name="datasets_num_parallel_batches",
        default=None,
        help=help_wrap(
            "Determines how many batches to process in parallel when using "
            "map and batch from tf.data.")
    )

  if training_dataset_cache:
    flags.DEFINE_boolean(
        name="training_dataset_cache",
        default=False,
        help=help_wrap(
            "Determines whether to cache the training dataset on workers. "
            "Typically used to improve training performance when training "
            "data is in remote storage and can fit into worker memory.")
    )

  if tf_data_experimental_slack:
    flags.DEFINE_boolean(
        name="tf_data_experimental_slack",
        default=False,
        help=help_wrap(
            "Whether to enable tf.data's `experimental_slack` option.")
    )

  if enable_xla:
    flags.DEFINE_boolean(
        name="enable_xla", default=False,
        help="Whether to enable XLA auto jit compilation")

  return key_flags