deanna-emery's picture
updates
93528c6
raw
history blame
6.41 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flags which will be nearly universal across models."""
from absl import flags
import tensorflow as tf, tf_keras
from official.utils.flags._conventions import help_wrap
def define_base(data_dir=True,
                model_dir=True,
                clean=False,
                train_epochs=False,
                epochs_between_evals=False,
                stop_threshold=False,
                batch_size=True,
                num_gpu=False,
                hooks=False,
                export_dir=False,
                distribution_strategy=False,
                run_eagerly=False):
  """Register base flags.

  Each boolean argument toggles the definition of one absl flag. The name of
  every flag defined by this call is collected into the returned list, so the
  caller can declare all of them as key flags.

  Args:
    data_dir: Create a flag for specifying the input data directory.
    model_dir: Create a flag for specifying the model file directory.
    clean: Create a flag for removing the model_dir.
    train_epochs: Create a flag to specify the number of training epochs.
    epochs_between_evals: Create a flag to specify the frequency of testing.
    stop_threshold: Create a flag to specify a threshold accuracy or other eval
      metric which should trigger the end of training.
    batch_size: Create a flag to specify the batch size.
    num_gpu: Create a flag to specify the number of GPUs used. Note that the
      flag itself is named `num_gpus`.
    hooks: Create a flag to specify hooks for logging.
    export_dir: Create a flag to specify where a SavedModel should be exported.
    distribution_strategy: Create a flag to specify which Distribution Strategy
      to use.
    run_eagerly: Create a flag to specify to run eagerly op by op.

  Returns:
    A list with the names of the flags defined by this call, for core.py to
    mark as key flags.
  """
  key_flags = []

  if data_dir:
    flags.DEFINE_string(
        name="data_dir",
        short_name="dd",
        default="/tmp",
        help=help_wrap("The location of the input data."))
    key_flags.append("data_dir")

  if model_dir:
    flags.DEFINE_string(
        name="model_dir",
        short_name="md",
        default="/tmp",
        help=help_wrap("The location of the model checkpoint files."))
    key_flags.append("model_dir")

  if clean:
    flags.DEFINE_boolean(
        name="clean",
        default=False,
        help=help_wrap("If set, model_dir will be removed if it exists."))
    key_flags.append("clean")

  if train_epochs:
    flags.DEFINE_integer(
        name="train_epochs",
        short_name="te",
        default=1,
        help=help_wrap("The number of epochs used to train."))
    key_flags.append("train_epochs")

  if epochs_between_evals:
    flags.DEFINE_integer(
        name="epochs_between_evals",
        short_name="ebe",
        default=1,
        help=help_wrap("The number of training epochs to run between "
                       "evaluations."))
    key_flags.append("epochs_between_evals")

  if stop_threshold:
    # Defaults to None: per the help text, the threshold only applies when
    # explicitly passed.
    flags.DEFINE_float(
        name="stop_threshold",
        short_name="st",
        default=None,
        help=help_wrap("If passed, training will stop at the earlier of "
                       "train_epochs and when the evaluation metric is "
                       "greater than or equal to stop_threshold."))
    # Previously omitted: every defined flag belongs in the key-flag list.
    key_flags.append("stop_threshold")

  if batch_size:
    flags.DEFINE_integer(
        name="batch_size",
        short_name="bs",
        default=32,
        help=help_wrap("Batch size for training and evaluation. When using "
                       "multiple gpus, this is the global batch size for "
                       "all devices. For example, if the batch size is 32 "
                       "and there are 4 GPUs, each GPU will get 8 examples on "
                       "each step."))
    key_flags.append("batch_size")

  if num_gpu:
    # get_num_gpus() in this module interprets a value of -1 as "use all
    # local GPUs".
    flags.DEFINE_integer(
        name="num_gpus",
        short_name="ng",
        default=1,
        help=help_wrap("How many GPUs to use at each worker with the "
                       "DistributionStrategies API. The default is 1."))
    key_flags.append("num_gpus")

  if run_eagerly:
    # Wrapped like every other help string in this function.
    flags.DEFINE_boolean(
        name="run_eagerly",
        default=False,
        help=help_wrap(
            "Run the model op by op without building a model function."))
    key_flags.append("run_eagerly")

  if hooks:
    # DEFINE_list parses the comma-separated default string into a list.
    flags.DEFINE_list(
        name="hooks",
        short_name="hk",
        default="LoggingTensorHook",
        help=help_wrap(
            u"A list of (case insensitive) strings to specify the names of "
            u"training hooks. Example: `--hooks ProfilerHook,"
            u"ExamplesPerSecondHook`\n See hooks_helper "
            u"for details."))
    key_flags.append("hooks")

  if export_dir:
    flags.DEFINE_string(
        name="export_dir",
        short_name="ed",
        default=None,
        help=help_wrap("If set, a SavedModel serialization of the model will "
                       "be exported to this directory at the end of training. "
                       "See the README for more details and relevant links."))
    key_flags.append("export_dir")

  if distribution_strategy:
    flags.DEFINE_string(
        name="distribution_strategy",
        short_name="ds",
        default="mirrored",
        help=help_wrap("The Distribution Strategy to use for training. "
                       "Accepted values are 'off', 'one_device', "
                       "'mirrored', 'parameter_server', 'collective', "
                       "case insensitive. 'off' means not to use "
                       "Distribution Strategy; 'default' means to choose "
                       "from `MirroredStrategy` or `OneDeviceStrategy` "
                       "according to the number of GPUs."))
    key_flags.append("distribution_strategy")

  return key_flags
def get_num_gpus(flags_obj):
  """Return the number of GPUs to use, treating num_gpus=-1 as 'use all'.

  Args:
    flags_obj: An object exposing a `num_gpus` attribute (presumably the
      parsed absl FLAGS after `define_base(num_gpu=True)` — verify at caller).

  Returns:
    `flags_obj.num_gpus` unchanged when it is not -1; otherwise the number of
    devices reported by `device_lib.list_local_devices()` whose `device_type`
    is "GPU".
  """
  if flags_obj.num_gpus != -1:
    return flags_obj.num_gpus

  # Deferred import: only the -1 ("use all") path needs the device query.
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
  local_device_protos = device_lib.list_local_devices()
  # Generator expression: no throwaway list is built just to be summed.
  return sum(1 for d in local_device_protos if d.device_type == "GPU")