# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Register flags for optimizing performance."""

import multiprocessing

from absl import flags  # pylint: disable=g-bad-import-order
import tensorflow as tf, tf_keras  # pylint: disable=g-bad-import-order

from official.utils.flags._conventions import help_wrap

# Map string to TensorFlow dtype
DTYPE_MAP = {
    "fp16": tf.float16,
    "bf16": tf.bfloat16,
    "fp32": tf.float32,
}


def get_tf_dtype(flags_obj):
  """Returns the tf.DType implied by the --dtype flag."""
  if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
    # If the graph_rewrite is used, we build the graph with fp32, and let the
    # graph rewrite change ops to fp16.
    return tf.float32
  return DTYPE_MAP[flags_obj.dtype]


def get_loss_scale(flags_obj, default_for_fp16):
  """Returns the loss scale implied by the --loss_scale and --dtype flags."""
  dtype = get_tf_dtype(flags_obj)
  if flags_obj.loss_scale == "dynamic":
    return flags_obj.loss_scale
  elif flags_obj.loss_scale is not None:
    return float(flags_obj.loss_scale)
  elif dtype == tf.float32 or dtype == tf.bfloat16:
    return 1  # No loss scaling is needed for fp32 or bf16.
  else:
    assert dtype == tf.float16
    return default_for_fp16
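
# For illustration, assuming the flags defined by define_performance() below
# have been parsed, the two helpers above resolve roughly as follows (the flag
# values shown here are hypothetical examples):
#
#   --dtype=fp16 --loss_scale=dynamic  -> tf.float16, "dynamic"
#   --dtype=fp16 --loss_scale=512      -> tf.float16, 512.0
#   --dtype=fp16 (no --loss_scale)     -> tf.float16, default_for_fp16
#   --dtype=fp32 or --dtype=bf16       -> tf.float32 / tf.bfloat16, 1
#   --dtype=fp16 --fp16_implementation=graph_rewrite
#                                      -> tf.float32 (the graph rewrite later
#                                         changes ops to fp16)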


def define_performance(num_parallel_calls=False,
                       inter_op=False,
                       intra_op=False,
                       synthetic_data=False,
                       max_train_steps=False,
                       dtype=False,
                       all_reduce_alg=False,
                       num_packs=False,
                       tf_gpu_thread_mode=False,
                       datasets_num_private_threads=False,
                       datasets_num_parallel_batches=False,
                       fp16_implementation=False,
                       loss_scale=False,
                       tf_data_experimental_slack=False,
                       enable_xla=False,
                       training_dataset_cache=False):
  """Register flags for specifying performance tuning arguments.

  Args:
    num_parallel_calls: Create a flag to specify parallelism of data loading.
    inter_op: Create a flag to allow specification of inter op threads.
    intra_op: Create a flag to allow specification of intra op threads.
    synthetic_data: Create a flag to allow the use of synthetic data.
    max_train_steps: Create a flag to allow specification of the maximum
      number of training steps.
    dtype: Create flags for specifying dtype.
    all_reduce_alg: If set, forces a specific algorithm for multi-gpu.
    num_packs: If set, provides number of packs for MirroredStrategy's cross
      device ops.
    tf_gpu_thread_mode: gpu_private triggers use of a private thread pool.
    datasets_num_private_threads: Number of private threads for datasets.
    datasets_num_parallel_batches: Determines how many batches to process in
      parallel when using map and batch from tf.data.
    fp16_implementation: Create fp16_implementation flag.
    loss_scale: Controls the loss scaling, normally for mixed-precision
      training. Can only be turned on if dtype is also True.
    tf_data_experimental_slack: Determines whether to enable tf.data's
      `experimental_slack` option.
    enable_xla: Determines if XLA (auto clustering) is turned on.
    training_dataset_cache: Whether to cache the training dataset on workers.
      Typically used to improve training performance when training data is in
      remote storage and can fit into worker memory.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []
  if num_parallel_calls:
    flags.DEFINE_integer(
        name="num_parallel_calls",
        short_name="npc",
        default=multiprocessing.cpu_count(),
        help=help_wrap("The number of records that are processed in parallel "
                       "during input processing. This can be optimized per "
                       "data set but for generally homogeneous data sets, "
                       "should be approximately the number of available CPU "
                       "cores. (default behavior)"))

  if inter_op:
    flags.DEFINE_integer(
        name="inter_op_parallelism_threads",
        short_name="inter",
        default=0,
        help=help_wrap("Number of inter_op_parallelism_threads to use for "
                       "CPU. See TensorFlow config.proto for details."))

  if intra_op:
    flags.DEFINE_integer(
        name="intra_op_parallelism_threads",
        short_name="intra",
        default=0,
        help=help_wrap("Number of intra_op_parallelism_threads to use for "
                       "CPU. See TensorFlow config.proto for details."))

  if synthetic_data:
    flags.DEFINE_bool(
        name="use_synthetic_data",
        short_name="synth",
        default=False,
        help=help_wrap(
            "If set, use fake data (zeroes) instead of a real dataset. "
            "This mode is useful for performance debugging, as it removes "
            "input processing steps, but will not learn anything."))

  if max_train_steps:
    flags.DEFINE_integer(
        name="max_train_steps",
        short_name="mts",
        default=None,
        help=help_wrap(
            "The model will stop training if the global_step reaches this "
            "value. If not set, training will run until the specified number "
            "of epochs have run as usual. It is generally recommended to set "
            "--train_epochs=1 when using this flag."))
  if dtype:
    flags.DEFINE_enum(
        name="dtype",
        short_name="dt",
        default="fp32",
        enum_values=DTYPE_MAP.keys(),
        help=help_wrap("The TensorFlow datatype used for calculations. "
                       "For 16-bit dtypes, variables and certain ops will "
                       "still be float32 for numeric stability."))

    if loss_scale:
      flags.DEFINE_string(
          name="loss_scale",
          short_name="ls",
          default=None,
          help=help_wrap(
              "The amount to scale the loss by when --dtype=fp16. This can "
              "be an int/float or the string 'dynamic'. Before gradients are "
              "computed, the loss is multiplied by the loss scale, making "
              "all gradients loss_scale times larger. To adjust for this, "
              "gradients are divided by the loss scale before being applied "
              "to variables. This is mathematically equivalent to training "
              "without a loss scale, but the loss scale helps avoid some "
              "intermediate gradients from underflowing to zero. The default "
              "is 'dynamic', which dynamically determines the optimal loss "
              "scale during training."))

      # Register the check below as an absl validator so invalid
      # --loss_scale values are rejected when flags are parsed.
      # pylint: disable=unused-variable
      @flags.validator(
          flag_name="loss_scale",
          message="loss_scale should be a positive int/float or the string "
                  "'dynamic'.")
      def _check_loss_scale(loss_scale):
        """Validator to check the loss scale flag is valid."""
        if loss_scale is None:
          return True  # null case is handled in get_loss_scale()

        if loss_scale == "dynamic":
          return True

        try:
          loss_scale = float(loss_scale)
        except ValueError:
          return False

        return loss_scale > 0
      # pylint: enable=unused-variable

    if fp16_implementation:
      flags.DEFINE_enum(
          name="fp16_implementation",
          default="keras",
          enum_values=("keras", "graph_rewrite"),
          help=help_wrap(
              "When --dtype=fp16, how fp16 should be implemented. This has "
              "no impact on correctness. 'keras' uses the "
              "tf_keras.mixed_precision API. 'graph_rewrite' uses the "
              "tf.compat.v1.mixed_precision."
              "enable_mixed_precision_graph_rewrite API."))

      # Register as a multi-flag validator so --fp16_implementation is only
      # accepted together with --dtype=fp16.
      @flags.multi_flags_validator(["fp16_implementation", "dtype"])
      def _check_fp16_implementation(flags_dict):
        """Validator to check fp16_implementation flag is valid."""
        if (flags_dict["fp16_implementation"] == "graph_rewrite" and
            flags_dict["dtype"] != "fp16"):
          raise flags.ValidationError("--fp16_implementation should not be "
                                      "specified unless --dtype=fp16")
        return True
  if all_reduce_alg:
    flags.DEFINE_string(
        name="all_reduce_alg",
        short_name="ara",
        default=None,
        help=help_wrap("Defines the algorithm to use for performing "
                       "all-reduce. When specified with MirroredStrategy for "
                       "single worker, this controls "
                       "tf.contrib.distribute.AllReduceCrossTowerOps. When "
                       "specified with MultiWorkerMirroredStrategy, this "
                       "controls "
                       "tf.distribute.experimental.CollectiveCommunication; "
                       "valid options are `ring` and `nccl`."))

  if num_packs:
    flags.DEFINE_integer(
        name="num_packs",
        default=1,
        help=help_wrap("Sets `num_packs` in the cross device ops used in "
                       "MirroredStrategy. For details, see "
                       "tf.distribute.NcclAllReduce."))

  if tf_gpu_thread_mode:
    flags.DEFINE_string(
        name="tf_gpu_thread_mode",
        short_name="gt_mode",
        default=None,
        help=help_wrap(
            "Whether and how the GPU device uses its own threadpool."))

    flags.DEFINE_integer(
        name="per_gpu_thread_count",
        short_name="pgtc",
        default=0,
        help=help_wrap("The number of threads to use for GPU. Only valid "
                       "when tf_gpu_thread_mode is not global."))

  if datasets_num_private_threads:
    flags.DEFINE_integer(
        name="datasets_num_private_threads",
        default=None,
        help=help_wrap(
            "Number of threads for a private threadpool created for all "
            "datasets computation."))

  if datasets_num_parallel_batches:
    flags.DEFINE_integer(
        name="datasets_num_parallel_batches",
        default=None,
        help=help_wrap(
            "Determines how many batches to process in parallel when using "
            "map and batch from tf.data."))

  if training_dataset_cache:
    flags.DEFINE_boolean(
        name="training_dataset_cache",
        default=False,
        help=help_wrap(
            "Determines whether to cache the training dataset on workers. "
            "Typically used to improve training performance when training "
            "data is in remote storage and can fit into worker memory."))

  if tf_data_experimental_slack:
    flags.DEFINE_boolean(
        name="tf_data_experimental_slack",
        default=False,
        help=help_wrap(
            "Whether to enable tf.data's `experimental_slack` option."))

  if enable_xla:
    flags.DEFINE_boolean(
        name="enable_xla",
        default=False,
        help="Whether to enable XLA auto jit compilation")

  return key_flags
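
# A minimal usage sketch, assuming the module is driven from a standalone
# absl script. The choice of which flags to register and the
# default_for_fp16=128 value are illustrative, not requirements of this
# module:
#
#   from absl import app
#   from absl import flags
#   from official.utils.flags import _performance
#
#   _performance.define_performance(dtype=True, loss_scale=True,
#                                   fp16_implementation=True)
#
#   def main(_):
#     flags_obj = flags.FLAGS
#     dtype = _performance.get_tf_dtype(flags_obj)
#     loss_scale = _performance.get_loss_scale(flags_obj, default_for_fp16=128)
#     # ... build the input pipeline and model using dtype and loss_scale ...
#
#   if __name__ == "__main__":
#     app.run(main)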