|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Dataclasses for optimizer configs.""" |
|
from typing import List, Optional |
|
|
|
import dataclasses |
|
from official.modeling.hyperparams import base_config |
|
|
|
|
|
@dataclasses.dataclass
class SGDConfig(base_config.Config):
  """Configuration for SGD optimizer.

  The attributes for this class match the arguments of tf.keras.optimizers.SGD.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate for SGD optimizer.
    decay: decay rate for SGD optimizer.
    nesterov: whether to apply Nesterov momentum.
    momentum: momentum for SGD optimizer.
  """
  name: str = "SGD"
  learning_rate: float = 0.01
  decay: float = 0.0
  nesterov: bool = False
  momentum: float = 0.0
|
|
|
|
|
@dataclasses.dataclass
class RMSPropConfig(base_config.Config):
  """Configuration for RMSprop optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.RMSprop.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate for RMSprop optimizer.
    rho: discounting factor for RMSprop optimizer.
    momentum: momentum for RMSprop optimizer.
    epsilon: epsilon value for RMSprop optimizer, helps with numerical
      stability.
    centered: whether to normalize gradients by their estimated variance
      (centered RMSprop) or not.
  """
  name: str = "RMSprop"
  learning_rate: float = 0.001
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
  centered: bool = False
|
|
|
|
|
@dataclasses.dataclass
class AdamConfig(base_config.Config):
  """Configuration for Adam optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.Adam.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate for Adam optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in Adam optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and beyond".
  """
  name: str = "Adam"
  learning_rate: float = 0.001
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
|
|
|
|
|
@dataclasses.dataclass
class AdamWeightDecayConfig(base_config.Config):
  """Configuration for Adam optimizer with weight decay.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate for the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in the optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and beyond".
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    include_in_weight_decay: list[str], or None. List of weight names to
      include in weight decay.
    exclude_from_weight_decay: list[str], or None. List of weight names to
      exclude from weight decay.
  """
  name: str = "AdamWeightDecay"
  learning_rate: float = 0.001
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  weight_decay_rate: float = 0.0
  include_in_weight_decay: Optional[List[str]] = None
  exclude_from_weight_decay: Optional[List[str]] = None
|
|
|
|
|
@dataclasses.dataclass
class LAMBConfig(base_config.Config):
  """Configuration for LAMB optimizer.

  The attributes for this class match the arguments of
  tensorflow_addons.optimizers.LAMB.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate for LAMB optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in LAMB optimizer.
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    exclude_from_weight_decay: List of regex patterns of variables excluded
      from weight decay. Variables whose name contains a substring matching
      the pattern will be excluded.
    exclude_from_layer_adaptation: List of regex patterns of variables
      excluded from layer adaptation. Variables whose name contains a
      substring matching the pattern will be excluded.
  """
  name: str = "LAMB"
  learning_rate: float = 0.001
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
|
|