"""Flags which will be nearly universal across models.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
from absl import flags |
|
import tensorflow as tf |
|
from official.utils.flags._conventions import help_wrap |
|
|
|
|
|
def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False, |
|
epochs_between_evals=False, stop_threshold=False, |
|
batch_size=True, num_gpu=False, hooks=False, export_dir=False, |
|
distribution_strategy=False, run_eagerly=False): |
|
"""Register base flags. |
|
|
|
Args: |
|
data_dir: Create a flag for specifying the input data directory. |
|
model_dir: Create a flag for specifying the model file directory. |
|
clean: Create a flag for removing the model_dir. |
|
train_epochs: Create a flag to specify the number of training epochs. |
|
epochs_between_evals: Create a flag to specify the frequency of testing. |
|
stop_threshold: Create a flag to specify a threshold accuracy or other |
|
eval metric which should trigger the end of training. |
|
batch_size: Create a flag to specify the batch size. |
|
num_gpu: Create a flag to specify the number of GPUs used. |
|
hooks: Create a flag to specify hooks for logging. |
|
export_dir: Create a flag to specify where a SavedModel should be exported. |
|
distribution_strategy: Create a flag to specify which Distribution Strategy |
|
to use. |
|
run_eagerly: Create a flag to specify to run eagerly op by op. |
|
Returns: |
|
A list of flags for core.py to marks as key flags. |
|
""" |
|
  key_flags = []

  if data_dir:
    flags.DEFINE_string(
        name="data_dir", short_name="dd", default="/tmp",
        help=help_wrap("The location of the input data."))
    key_flags.append("data_dir")

  if model_dir:
    flags.DEFINE_string(
        name="model_dir", short_name="md", default="/tmp",
        help=help_wrap("The location of the model checkpoint files."))
    key_flags.append("model_dir")

  if clean:
    flags.DEFINE_boolean(
        name="clean", default=False,
        help=help_wrap("If set, model_dir will be removed if it exists."))
    key_flags.append("clean")

  if train_epochs:
    flags.DEFINE_integer(
        name="train_epochs", short_name="te", default=1,
        help=help_wrap("The number of epochs used to train."))
    key_flags.append("train_epochs")

  if epochs_between_evals:
    flags.DEFINE_integer(
        name="epochs_between_evals", short_name="ebe", default=1,
        help=help_wrap("The number of training epochs to run between "
                       "evaluations."))
    key_flags.append("epochs_between_evals")

  if stop_threshold:
    flags.DEFINE_float(
        name="stop_threshold", short_name="st",
        default=None,
        help=help_wrap("If passed, training will stop when the evaluation "
                       "metric is greater than or equal to stop_threshold, "
                       "or after train_epochs, whichever comes first."))

  if batch_size:
    flags.DEFINE_integer(
        name="batch_size", short_name="bs", default=32,
        help=help_wrap("Batch size for training and evaluation. When using "
                       "multiple GPUs, this is the global batch size for "
                       "all devices. For example, if the batch size is 32 "
                       "and there are 4 GPUs, each GPU will get 8 examples "
                       "on each step."))
    key_flags.append("batch_size")

  if num_gpu:
    flags.DEFINE_integer(
        name="num_gpus", short_name="ng",
        default=1,
        help=help_wrap(
            "How many GPUs to use at each worker with the "
            "DistributionStrategies API. The default is 1. Pass -1 to use "
            "all GPUs visible to the worker (see get_num_gpus below)."))

  if run_eagerly:
    flags.DEFINE_boolean(
        name="run_eagerly", default=False,
        help=help_wrap("Run the model op by op without building a model "
                       "function."))

  if hooks:
    flags.DEFINE_list(
        name="hooks", short_name="hk", default="LoggingTensorHook",
        help=help_wrap(
            u"A list of (case insensitive) strings to specify the names of "
            u"training hooks. Example: `--hooks ProfilerHook,"
            u"ExamplesPerSecondHook`\nSee hooks_helper for details.")
    )
    key_flags.append("hooks")

  if export_dir:
    flags.DEFINE_string(
        name="export_dir", short_name="ed", default=None,
        help=help_wrap("If set, a SavedModel serialization of the model will "
                       "be exported to this directory at the end of training. "
                       "See the README for more details and relevant links.")
    )
    key_flags.append("export_dir")

  if distribution_strategy:
    flags.DEFINE_string(
        name="distribution_strategy", short_name="ds", default="mirrored",
        help=help_wrap("The Distribution Strategy to use for training. "
                       "Accepted values are 'off', 'default', 'one_device', "
                       "'mirrored', 'parameter_server', and 'collective', "
                       "case insensitive. 'off' means not to use "
                       "Distribution Strategy; 'default' means to choose "
                       "from `MirroredStrategy` or `OneDeviceStrategy` "
                       "according to the number of GPUs.")
    )

  return key_flags
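

# A minimal usage sketch (illustrative only; `main` below is hypothetical):
# a model's main script registers these flags before absl parses argv, then
# reads the parsed values from flags.FLAGS. Marking the returned list as key
# flags is handled by core.py.
#
#   from absl import app, flags
#
#   def main(_):
#     print(flags.FLAGS.data_dir, flags.FLAGS.batch_size,
#           get_num_gpus(flags.FLAGS))
#
#   if __name__ == "__main__":
#     define_base(num_gpu=True)
#     app.run(main)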


def get_num_gpus(flags_obj):
  """Treat num_gpus=-1 as 'use all'."""
  if flags_obj.num_gpus != -1:
    return flags_obj.num_gpus

  # Deferred import: this private TF module is only needed on the -1
  # fallback path, so it is imported here rather than at module load time.
  from tensorflow.python.client import device_lib
  local_device_protos = device_lib.list_local_devices()
  return sum(1 for d in local_device_protos if d.device_type == "GPU")
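

# A sketch of the -1 convention (assumes define_base(num_gpu=True) ran and
# argv has been parsed, e.g. via app.run):
#
#   flags.FLAGS(["prog", "--num_gpus=-1"])
#   print(get_num_gpus(flags.FLAGS))  # Prints the number of local GPUs.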
|
|