|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Common configuration settings.""" |
|
from typing import Optional, Union |
|
|
|
import dataclasses |
|
|
|
from official.modeling.hyperparams import base_config |
|
from official.modeling.optimization.configs import optimization_config |
|
from official.utils import registry |
|
|
|
OptimizationConfig = optimization_config.OptimizationConfig |
|
|
|
|
|
@dataclasses.dataclass |
|
class DataConfig(base_config.Config): |
|
"""The base configuration for building datasets. |
|
|
|
Attributes: |
|
input_path: The path to the input. It can be either (1) a file pattern, or |
|
(2) multiple file patterns separated by comma. It should not be specified |
|
when the following `tfds_name` is specified. |
|
tfds_name: The name of the tensorflow dataset (TFDS). It should not be |
|
specified when the above `input_path` is specified. |
|
tfds_split: A str indicating which split of the data to load from TFDS. It |
|
is required when above `tfds_name` is specified. |
|
global_batch_size: The global batch size across all replicas. |
|
is_training: Whether this data is used for training or not. |
|
drop_remainder: Whether the last batch should be dropped in the case it has |
|
fewer than `global_batch_size` elements. |
|
shuffle_buffer_size: The buffer size used for shuffling training data. |
|
cache: Whether to cache dataset examples. Can be used to avoid re-reading |
|
from disk on the second epoch. Requires significant memory overhead. |
|
cycle_length: The number of files that will be processed concurrently when |
|
interleaving files. |
|
block_length: The number of consecutive elements to produce from each input |
|
element before cycling to another input element when interleaving files. |
|
sharding: Whether sharding is used in the input pipeline. |
|
examples_consume: An `integer` specifying the number of examples it will |
|
produce. If positive, it only takes this number of examples and raises |
|
tf.error.OutOfRangeError after that. Default is -1, meaning it will |
|
exhaust all the examples in the dataset. |
|
tfds_data_dir: A str specifying the directory to read/write TFDS data. |
|
tfds_download: A bool to indicate whether to download data using TFDS. |
|
tfds_as_supervised: A bool. When loading dataset from TFDS, if True, |
|
the returned tf.data.Dataset will have a 2-tuple structure (input, label) |
|
according to builder.info.supervised_keys; if False, the default, |
|
the returned tf.data.Dataset will have a dictionary with all the features. |
|
tfds_skip_decoding_feature: A str to indicate which features are skipped |
|
for decoding when loading dataset from TFDS. Use comma to separate |
|
multiple features. The main use case is to skip the image/video decoding |
|
for better performance. |
|
""" |
|
input_path: str = "" |
|
tfds_name: str = "" |
|
tfds_split: str = "" |
|
global_batch_size: int = 0 |
|
is_training: bool = None |
|
drop_remainder: bool = True |
|
shuffle_buffer_size: int = 100 |
|
cache: bool = False |
|
cycle_length: int = 8 |
|
block_length: int = 1 |
|
sharding: bool = True |
|
examples_consume: int = -1 |
|
tfds_data_dir: str = "" |
|
tfds_download: bool = False |
|
tfds_as_supervised: bool = False |
|
tfds_skip_decoding_feature: str = "" |
|
|
|
|
|
@dataclasses.dataclass |
|
class RuntimeConfig(base_config.Config): |
|
"""High-level configurations for Runtime. |
|
|
|
These include parameters that are not directly related to the experiment, |
|
e.g. directories, accelerator type, etc. |
|
|
|
Attributes: |
|
distribution_strategy: e.g. 'mirrored', 'tpu', etc. |
|
enable_xla: Whether or not to enable XLA. |
|
per_gpu_thread_count: thread count per GPU. |
|
gpu_thread_mode: Whether and how the GPU device uses its own threadpool. |
|
dataset_num_private_threads: Number of threads for a private threadpool |
|
created for all datasets computation. |
|
tpu: The address of the TPU to use, if any. |
|
num_gpus: The number of GPUs to use, if any. |
|
worker_hosts: comma-separated list of worker ip:port pairs for running |
|
multi-worker models with DistributionStrategy. |
|
task_index: If multi-worker training, the task index of this worker. |
|
all_reduce_alg: Defines the algorithm for performing all-reduce. |
|
num_packs: Sets `num_packs` in the cross device ops used in |
|
MirroredStrategy. For details, see tf.distribute.NcclAllReduce. |
|
mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32', |
|
'float16', or 'bfloat16'. |
|
loss_scale: The type of loss scale, or 'float' value. This is used when |
|
setting the mixed precision policy. |
|
run_eagerly: Whether or not to run the experiment eagerly. |
|
batchnorm_spatial_persistent: Whether or not to enable the spatial |
|
persistent mode for CuDNN batch norm kernel for improved GPU performance. |
|
""" |
|
distribution_strategy: str = "mirrored" |
|
enable_xla: bool = False |
|
gpu_thread_mode: Optional[str] = None |
|
dataset_num_private_threads: Optional[int] = None |
|
per_gpu_thread_count: int = 0 |
|
tpu: Optional[str] = None |
|
num_gpus: int = 0 |
|
worker_hosts: Optional[str] = None |
|
task_index: int = -1 |
|
all_reduce_alg: Optional[str] = None |
|
num_packs: int = 1 |
|
loss_scale: Optional[Union[str, float]] = None |
|
mixed_precision_dtype: Optional[str] = None |
|
run_eagerly: bool = False |
|
batchnorm_spatial_persistent: bool = False |
|
|
|
|
|
@dataclasses.dataclass |
|
class TensorboardConfig(base_config.Config): |
|
"""Configuration for Tensorboard. |
|
|
|
Attributes: |
|
track_lr: Whether or not to track the learning rate in Tensorboard. Defaults |
|
to True. |
|
write_model_weights: Whether or not to write the model weights as images in |
|
Tensorboard. Defaults to False. |
|
""" |
|
track_lr: bool = True |
|
write_model_weights: bool = False |
|
|
|
|
|
@dataclasses.dataclass |
|
class CallbacksConfig(base_config.Config): |
|
"""Configuration for Callbacks. |
|
|
|
Attributes: |
|
enable_checkpoint_and_export: Whether or not to enable checkpoints as a |
|
Callback. Defaults to True. |
|
enable_tensorboard: Whether or not to enable Tensorboard as a Callback. |
|
Defaults to True. |
|
enable_time_history: Whether or not to enable TimeHistory Callbacks. |
|
Defaults to True. |
|
""" |
|
enable_checkpoint_and_export: bool = True |
|
enable_tensorboard: bool = True |
|
enable_time_history: bool = True |
|
|
|
|
|
@dataclasses.dataclass |
|
class TrainerConfig(base_config.Config): |
|
"""Configuration for trainer. |
|
|
|
Attributes: |
|
optimizer_config: optimizer config, it includes optimizer, learning rate, |
|
and warmup schedule configs. |
|
train_tf_while_loop: whether or not to use tf while loop. |
|
train_tf_function: whether or not to use tf_function for training loop. |
|
eval_tf_function: whether or not to use tf_function for eval. |
|
steps_per_loop: number of steps per loop. |
|
summary_interval: number of steps between each summary. |
|
checkpoint_intervals: number of steps between checkpoints. |
|
max_to_keep: max checkpoints to keep. |
|
continuous_eval_timeout: maximum number of seconds to wait between |
|
checkpoints, if set to None, continuous eval will wait indefinetely. |
|
""" |
|
optimizer_config: OptimizationConfig = OptimizationConfig() |
|
train_tf_while_loop: bool = True |
|
train_tf_function: bool = True |
|
eval_tf_function: bool = True |
|
steps_per_loop: int = 1000 |
|
summary_interval: int = 1000 |
|
checkpoint_interval: int = 1000 |
|
max_to_keep: int = 5 |
|
continuous_eval_timeout: Optional[int] = None |
|
|
|
|
|
@dataclasses.dataclass |
|
class TaskConfig(base_config.Config): |
|
network: base_config.Config = None |
|
train_data: DataConfig = DataConfig() |
|
validation_data: DataConfig = DataConfig() |
|
|
|
|
|
@dataclasses.dataclass |
|
class ExperimentConfig(base_config.Config): |
|
"""Top-level configuration.""" |
|
task: TaskConfig = TaskConfig() |
|
trainer: TrainerConfig = TrainerConfig() |
|
runtime: RuntimeConfig = RuntimeConfig() |
|
train_steps: int = 0 |
|
validation_steps: Optional[int] = None |
|
validation_interval: int = 100 |
|
|
|
|
|
_REGISTERED_CONFIGS = {} |
|
|
|
|
|
def register_config_factory(name): |
|
"""Register ExperimentConfig factory method.""" |
|
return registry.register(_REGISTERED_CONFIGS, name) |
|
|
|
|
|
def get_exp_config_creater(exp_name: str): |
|
"""Looks up ExperimentConfig factory methods.""" |
|
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name) |
|
return exp_creater |
|
|