# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The central place to define flags."""

from absl import flags


def define_flags():
  """Defines flags.

  All flags are defined as optional, but in practice most models use some of
  these flags and so mark_flags_as_required() should be called after calling
  this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
  For example:

  ```
  from absl import flags
  from official.common import flags as tfm_flags  # pylint: disable=line-too-long
  ...
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  ```

  The reason all flags are optional is because unit tests often do not set or
  use any of the flags.

  Flags defined here (all default to None): `experiment`, `mode`, `model_dir`,
  `config_file`, `params_override`, `gin_file`, `gin_params`, `tpu`,
  `tf_data_service` and `tpu_platform`.

  A pre-existing definition of `gin_file` or `gin_params` is tolerated: the
  resulting `flags.DuplicateFlagError` is caught and ignored. No such handling
  is applied to the other flags.
  """
  flags.DEFINE_string(
      'experiment',
      default=None,
      help='The experiment type registered, specifying an ExperimentConfig.')

  flags.DEFINE_enum(
      'mode',
      default=None,
      enum_values=[
          'train', 'eval', 'train_and_eval', 'continuous_eval',
          'continuous_train_and_eval', 'train_and_validate',
          'train_and_post_eval'
      ],
      # Every value accepted by `enum_values` is listed in the help text so
      # that `--help` does not hide a valid mode.
      help='Mode to run: `train`, `eval`, `train_and_eval`, '
      '`continuous_eval`, `continuous_train_and_eval`, '
      '`train_and_post_eval` and '
      '`train_and_validate` (which is not implemented in '
      'the open source version).')

  flags.DEFINE_string(
      'model_dir',
      default=None,
      # Adjacent literals are concatenated verbatim, so the first fragment
      # must end with a space to keep "summaries are" as two words.
      help='The directory where the model and training/evaluation summaries '
      'are stored.')

  flags.DEFINE_multi_string(
      'config_file',
      default=None,
      help='YAML/JSON files which specifies overrides. The override order '
      'follows the order of args. Note that each file '
      'can be used as an override template to override the default parameters '
      'specified in Python. If the same parameter is specified in both '
      '`--config_file` and `--params_override`, `config_file` will be used '
      'first, followed by params_override.')

  flags.DEFINE_string(
      'params_override',
      default=None,
      help='a YAML/JSON string or a YAML file which specifies additional '
      'overrides over the default parameters and those specified in '
      '`--config_file`. Note that this is supposed to be used only to override '
      'the model parameters, but not the parameters like TPU specific flags. '
      'One canonical use case of `--config_file` and `--params_override` is '
      'users first define a template config file using `--config_file`, then '
      'use `--params_override` to adjust the minimal set of tuning parameters, '
      'for example setting up different `train_batch_size`. The final override '
      'order of parameters: default_model_params --> params from config_file '
      '--> params in params_override. See also the help message of '
      '`--config_file`.')

  # Libraries that rely on gin often define these flags inside their own
  # library files, which causes conflicts when the flags are defined again
  # here. A duplicate definition is therefore deliberately ignored rather than
  # treated as an error.
  try:
    flags.DEFINE_multi_string(
        'gin_file', default=None, help='List of paths to the config files.')
  except flags.DuplicateFlagError:
    pass

  try:
    flags.DEFINE_multi_string(
        'gin_params',
        default=None,
        help='Newline separated list of Gin parameter bindings.')
  except flags.DuplicateFlagError:
    pass

  flags.DEFINE_string(
      'tpu',
      default=None,
      help='The Cloud TPU to use for training. This should be either the name '
      'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
      'url.')

  flags.DEFINE_string(
      'tf_data_service', default=None, help='The tf.data service address')

  flags.DEFINE_string(
      'tpu_platform', default=None, help='TPU platform type.')