# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flags which will be nearly universal across models.""" | |
from absl import flags | |
import tensorflow as tf, tf_keras | |
from official.utils.flags._conventions import help_wrap | |
def define_base(data_dir=True,
                model_dir=True,
                clean=False,
                train_epochs=False,
                epochs_between_evals=False,
                stop_threshold=False,
                batch_size=True,
                num_gpu=False,
                hooks=False,
                export_dir=False,
                distribution_strategy=False,
                run_eagerly=False):
  """Register base flags.

  Each boolean argument toggles the registration of one absl flag; the
  function only registers flags, it does not parse them.

  Args:
    data_dir: Create a flag for specifying the input data directory.
    model_dir: Create a flag for specifying the model file directory.
    clean: Create a flag for removing the model_dir.
    train_epochs: Create a flag to specify the number of training epochs.
    epochs_between_evals: Create a flag to specify the frequency of testing.
    stop_threshold: Create a flag to specify a threshold accuracy or other eval
      metric which should trigger the end of training.
    batch_size: Create a flag to specify the batch size.
    num_gpu: Create a flag to specify the number of GPUs used.
    hooks: Create a flag to specify hooks for logging.
    export_dir: Create a flag to specify where a SavedModel should be exported.
    distribution_strategy: Create a flag to specify which Distribution Strategy
      to use.
    run_eagerly: Create a flag to specify to run eagerly op by op.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  # Names appended here are surfaced as "key flags" in --help output.
  # NOTE(review): stop_threshold, num_gpus, and run_eagerly are registered but
  # intentionally never appended to key_flags — confirm against upstream before
  # changing, as this affects --help grouping only, not flag availability.
  key_flags = []

  if data_dir:
    flags.DEFINE_string(
        name="data_dir",
        short_name="dd",
        default="/tmp",
        help=help_wrap("The location of the input data."))
    key_flags.append("data_dir")

  if model_dir:
    flags.DEFINE_string(
        name="model_dir",
        short_name="md",
        default="/tmp",
        help=help_wrap("The location of the model checkpoint files."))
    key_flags.append("model_dir")

  if clean:
    flags.DEFINE_boolean(
        name="clean",
        default=False,
        help=help_wrap("If set, model_dir will be removed if it exists."))
    key_flags.append("clean")

  if train_epochs:
    flags.DEFINE_integer(
        name="train_epochs",
        short_name="te",
        default=1,
        help=help_wrap("The number of epochs used to train."))
    key_flags.append("train_epochs")

  if epochs_between_evals:
    flags.DEFINE_integer(
        name="epochs_between_evals",
        short_name="ebe",
        default=1,
        help=help_wrap("The number of training epochs to run between "
                       "evaluations."))
    key_flags.append("epochs_between_evals")

  if stop_threshold:
    flags.DEFINE_float(
        name="stop_threshold",
        short_name="st",
        default=None,
        help=help_wrap("If passed, training will stop at the earlier of "
                       "train_epochs and when the evaluation metric is "
                       "greater than or equal to stop_threshold."))

  if batch_size:
    flags.DEFINE_integer(
        name="batch_size",
        short_name="bs",
        default=32,
        help=help_wrap("Batch size for training and evaluation. When using "
                       "multiple gpus, this is the global batch size for "
                       "all devices. For example, if the batch size is 32 "
                       "and there are 4 GPUs, each GPU will get 8 examples on "
                       "each step."))
    key_flags.append("batch_size")

  if num_gpu:
    flags.DEFINE_integer(
        name="num_gpus",
        short_name="ng",
        default=1,
        help=help_wrap("How many GPUs to use at each worker with the "
                       "DistributionStrategies API. The default is 1."))

  if run_eagerly:
    flags.DEFINE_boolean(
        name="run_eagerly",
        default=False,
        help="Run the model op by op without building a model function.")

  if hooks:
    flags.DEFINE_list(
        name="hooks",
        short_name="hk",
        default="LoggingTensorHook",
        help=help_wrap(
            u"A list of (case insensitive) strings to specify the names of "
            u"training hooks. Example: `--hooks ProfilerHook,"
            u"ExamplesPerSecondHook`\n See hooks_helper "
            u"for details."))
    key_flags.append("hooks")

  if export_dir:
    flags.DEFINE_string(
        name="export_dir",
        short_name="ed",
        default=None,
        help=help_wrap("If set, a SavedModel serialization of the model will "
                       "be exported to this directory at the end of training. "
                       "See the README for more details and relevant links."))
    key_flags.append("export_dir")

  if distribution_strategy:
    flags.DEFINE_string(
        name="distribution_strategy",
        short_name="ds",
        default="mirrored",
        help=help_wrap("The Distribution Strategy to use for training. "
                       "Accepted values are 'off', 'one_device', "
                       "'mirrored', 'parameter_server', 'collective', "
                       "case insensitive. 'off' means not to use "
                       "Distribution Strategy; 'default' means to choose "
                       "from `MirroredStrategy` or `OneDeviceStrategy` "
                       "according to the number of GPUs."))

  return key_flags
def get_num_gpus(flags_obj):
  """Treat num_gpus=-1 as 'use all'.

  Args:
    flags_obj: A parsed flags object exposing a `num_gpus` attribute.

  Returns:
    `flags_obj.num_gpus` unchanged when it is not -1; otherwise the number of
    locally visible GPU devices reported by TensorFlow.
  """
  if flags_obj.num_gpus != -1:
    return flags_obj.num_gpus

  # Imported lazily so merely importing this module does not trigger device
  # enumeration (which can initialize GPUs).
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
  local_device_protos = device_lib.list_local_devices()
  # Generator form avoids building an intermediate list just to count.
  return sum(1 for d in local_device_protos if d.device_type == "GPU")