File size: 5,304 Bytes
97b6013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataclasses for optimizer configs."""
from typing import List, Optional
import dataclasses
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class SGDConfig(base_config.Config):
  """Hyperparameter settings for the SGD optimizer.

  The attributes of this class match the arguments of tf.keras.optimizers.SGD.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate used by the SGD optimizer.
    decay: decay rate used by the SGD optimizer.
    nesterov: value passed as the `nesterov` argument of the SGD optimizer.
    momentum: momentum used by the SGD optimizer.
  """

  # Field order is part of the generated __init__ signature; keep it stable.
  name: str = "SGD"
  learning_rate: float = 0.01
  decay: float = 0.0
  nesterov: bool = False
  momentum: float = 0.0
@dataclasses.dataclass
class RMSPropConfig(base_config.Config):
  """Hyperparameter settings for the RMSprop optimizer.

  The attributes of this class match the arguments of
  tf.keras.optimizers.RMSprop.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate used by the RMSprop optimizer.
    rho: discounting factor used by the RMSprop optimizer.
    momentum: momentum used by the RMSprop optimizer.
    epsilon: small constant that helps with numerical stability in the
      RMSprop optimizer.
    centered: whether or not to normalize gradients.
  """

  # Field order is part of the generated __init__ signature; keep it stable.
  name: str = "RMSprop"
  learning_rate: float = 0.001
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
  centered: bool = False
@dataclasses.dataclass
class AdamConfig(base_config.Config):
  """Hyperparameter settings for the Adam optimizer.

  The attributes of this class match the arguments of
  tf.keras.optimizers.Adam.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate used by the Adam optimizer.
    beta_1: decay rate for the 1st order moments.
    beta_2: decay rate for the 2nd order moments.
    epsilon: small constant used for numerical stability in the Adam
      optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and beyond".
  """

  # Field order is part of the generated __init__ signature; keep it stable.
  name: str = "Adam"
  learning_rate: float = 0.001
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-7
  amsgrad: bool = False
@dataclasses.dataclass
class AdamWeightDecayConfig(base_config.Config):
  """Configuration for Adam optimizer with weight decay.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate for the optimizer.
    beta_1: decay rate for the 1st order moments.
    beta_2: decay rate for the 2nd order moments.
    epsilon: epsilon value used for numerical stability in the optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and beyond".
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    include_in_weight_decay: list[str], or None. List of weight names to
      include in weight decay.
    exclude_from_weight_decay: list[str], or None. List of weight names to
      not include in weight decay.
  """

  # Field order is part of the generated __init__ signature; keep it stable.
  name: str = "AdamWeightDecay"
  learning_rate: float = 0.001
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  weight_decay_rate: float = 0.0
  # None (rather than an empty list) is the "not specified" default for both
  # name lists.
  include_in_weight_decay: Optional[List[str]] = None
  exclude_from_weight_decay: Optional[List[str]] = None
@dataclasses.dataclass
class LAMBConfig(base_config.Config):
  """Hyperparameter settings for the LAMB optimizer.

  The attributes of this class match the arguments of
  tensorflow_addons.optimizers.LAMB.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate used by the LAMB optimizer.
    beta_1: decay rate for the 1st order moments.
    beta_2: decay rate for the 2nd order moments.
    epsilon: small constant used for numerical stability in the LAMB
      optimizer.
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    exclude_from_weight_decay: list of regex patterns of variables excluded
      from weight decay. Variables whose name contains a substring matching
      the pattern will be excluded.
    exclude_from_layer_adaptation: list of regex patterns of variables
      excluded from layer adaptation. Variables whose name contains a
      substring matching the pattern will be excluded.
  """

  # Field order is part of the generated __init__ signature; keep it stable.
  name: str = "LAMB"
  learning_rate: float = 0.001
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
|