Spaces:
Runtime error
Runtime error
File size: 4,286 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The central place to define flags."""
from absl import flags
def define_flags():
  """Defines the common command-line flags.

  All flags are defined as optional (every default is `None`), but in practice
  most models use some of these flags and so `mark_flags_as_required()` should
  be called after calling this function. Typically, 'experiment', 'mode', and
  'model_dir' are required.

  For example:

  ```
  from absl import flags
  from official.common import flags as tfm_flags  # pylint: disable=line-too-long
  ...
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  ```

  The reason all flags are optional is because unit tests often do not set or
  use any of the flags.

  Flags defined: `experiment`, `mode`, `model_dir`, `config_file`,
  `params_override`, `gin_file`, `gin_params`, `tpu`, `tf_data_service` and
  `tpu_platform`.

  Only `gin_file` and `gin_params` tolerate having been defined already; a
  `flags.DuplicateFlagError` raised while defining any other flag is not
  caught here.
  """
  flags.DEFINE_string(
      'experiment',
      default=None,
      help='The experiment type registered, specifying an ExperimentConfig.')

  flags.DEFINE_enum(
      'mode',
      default=None,
      enum_values=[
          'train', 'eval', 'train_and_eval', 'continuous_eval',
          'continuous_train_and_eval', 'train_and_validate',
          'train_and_post_eval'
      ],
      # Keep this list in sync with `enum_values` above.
      help='Mode to run: `train`, `eval`, `train_and_eval`, '
      '`continuous_eval`, `continuous_train_and_eval`, '
      '`train_and_post_eval` and '
      '`train_and_validate` (which is not implemented in '
      'the open source version).')

  flags.DEFINE_string(
      'model_dir',
      default=None,
      # Adjacent literals are concatenated verbatim, so the first fragment
      # must end with a space to keep the words apart.
      help='The directory where the model and training/evaluation summaries '
      'are stored.')

  flags.DEFINE_multi_string(
      'config_file',
      default=None,
      help='YAML/JSON files which specifies overrides. The override order '
      'follows the order of args. Note that each file '
      'can be used as an override template to override the default parameters '
      'specified in Python. If the same parameter is specified in both '
      '`--config_file` and `--params_override`, `config_file` will be used '
      'first, followed by params_override.')

  flags.DEFINE_string(
      'params_override',
      default=None,
      help='a YAML/JSON string or a YAML file which specifies additional '
      'overrides over the default parameters and those specified in '
      '`--config_file`. Note that this is supposed to be used only to override '
      'the model parameters, but not the parameters like TPU specific flags. '
      'One canonical use case of `--config_file` and `--params_override` is '
      'users first define a template config file using `--config_file`, then '
      'use `--params_override` to adjust the minimal set of tuning parameters, '
      'for example setting up different `train_batch_size`. The final override '
      'order of parameters: default_model_params --> params from config_file '
      '--> params in params_override. See also the help message of '
      '`--config_file`.')

  # Libraries that rely on gin often define these flags inside their own
  # library files, which would conflict with the definitions below. An
  # existing definition is therefore kept and the duplicate error is ignored.
  try:
    flags.DEFINE_multi_string(
        'gin_file', default=None, help='List of paths to the config files.')
  except flags.DuplicateFlagError:
    pass

  try:
    flags.DEFINE_multi_string(
        'gin_params',
        default=None,
        help='Newline separated list of Gin parameter bindings.')
  except flags.DuplicateFlagError:
    pass

  flags.DEFINE_string(
      'tpu',
      default=None,
      help='The Cloud TPU to use for training. This should be either the name '
      'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
      'url.')

  flags.DEFINE_string(
      'tf_data_service', default=None, help='The tf.data service address')

  flags.DEFINE_string(
      'tpu_platform', default=None, help='TPU platform type.')
|