|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
"""Schedule functions for controlling hparams over time.""" |
|
|
|
from abc import ABCMeta |
|
from abc import abstractmethod |
|
import math |
|
|
|
from common import config_lib |
|
|
|
|
|
class Schedule(ABCMeta('_ScheduleBase', (object,), {})):
    """Abstract base for functions that set a hyperparameter's value over time.

    A schedule can, for example, decay an hparam or oscillate it over time.

    A schedule is constructed from a config object (an instance of
    config_lib.Config whose fields are specific to each implementation). For
    a decay schedule the config may give the decay rate and the decay start
    time. The constructed instance is then called like a function, mapping
    the global step (an integer counting how many calls to the train op have
    been made) to the hparam value.

    Properties of a schedule function f(t):
      0) Domain of t is the non-negative integers (t may be 0).
      1) Range of f is the reals.
      2) Schedule functions can assume that they will be called in time
         order. This allows schedules to be stateful.
      3) Schedule functions should be deterministic. Two schedule instances
         with the same config must always give the same value for each t,
         regardless of which t's they were previously called on. Users may
         call f(t) on arbitrary (positive) time jumps. Essentially, multiple
         schedule instances used in replica training will behave the same.
      4) Duplicate successive calls on the same time are allowed.

    Both `__init__` and `__call__` are abstract: a subclass must override
    both before it can be instantiated.
    """

    # The ABCMeta-created base class above is what makes the abstract methods
    # enforced on Python 3, where a `__metaclass__` attribute is ignored.
    # The attribute is kept so that Python 2 behaviour is unchanged.
    __metaclass__ = ABCMeta

    @abstractmethod
    def __init__(self, config):
        """Construct this schedule with a config specific to each class impl.

        Subclasses may call this through `super()`; it does nothing itself.

        Args:
          config: An instance of config_lib.Config.
        """

    @abstractmethod
    def __call__(self, global_step):
        """Map `global_step` to a value.

        `global_step` is an integer counting how many calls to the train op
        have been made across all replicas (hence why it is global).
        Implementations may assume calls are made in time order, i.e.
        `global_step` now >= previous `global_step` values.

        Args:
          global_step: Non-negative integer.

        Returns:
          Hparam value at this step. A number.
        """
|
|
|
|
|
class ConstSchedule(Schedule):
    """Schedule that yields the same value at every step.

    config:
      const: Constant value at every step.

    f(t) = const.
    """

    def __init__(self, config):
        """Store `config.const` as the value returned on every call.

        Args:
          config: Config object exposing a `const` attribute.
        """
        super(ConstSchedule, self).__init__(config)
        self.const = config.const

    def __call__(self, global_step):
        """Return the configured constant; `global_step` is not consulted.

        Args:
          global_step: Non-negative integer (ignored).

        Returns:
          The value of `config.const` captured at construction.
        """
        del global_step  # The value does not depend on time.
        return self.const
|
|
|
|
|
class LinearDecaySchedule(Schedule):
    """Linear decay function.

    config:
      initial: Decay starts from this value.
      final: Decay ends at this value.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    f(t) is a linear function when start_time <= t <= end_time, with slope
    of (final - initial) / (end_time - start_time). f(t) = initial when
    t <= start_time. f(t) = final when t >= end_time.

    If start_time == end_time, this becomes a step function: f(t) = initial
    for t <= start_time and f(t) = final afterwards.
    """

    def __init__(self, config):
        """Read the decay endpoints from `config` and precompute the slope.

        Args:
          config: Config object exposing `initial`, `final`, `start_time`
            and `end_time`.

        Raises:
          ValueError: If `end_time` is smaller than `start_time`.
        """
        super(LinearDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time

        if self.end_time < self.start_time:
            raise ValueError('start_time must be before end_time.')

        # Linear interpolation parameters.
        self._time_diff = float(self.end_time - self.start_time)
        self._diff = float(self.final - self.initial)
        # With a zero-length window the slope is never used (__call__ returns
        # one of the two constants first); inf is only a placeholder that
        # avoids a division by zero.
        self._slope = (
            self._diff / self._time_diff
            if self._time_diff > 0 else float('inf'))

    def __call__(self, global_step):
        """Return the linearly interpolated value at `global_step`.

        Args:
          global_step: Non-negative integer.

        Returns:
          `initial` for global_step <= start_time, `final` for
          global_step >= end_time, and the interpolated value in between.
        """
        if global_step <= self.start_time:
            return self.initial
        # `>=` rather than `>`: at end_time the interpolated value
        # initial + time_diff * slope can miss `final` by float rounding,
        # so the configured value is returned directly, as documented.
        if global_step >= self.end_time:
            return self.final
        return self.initial + (global_step - self.start_time) * self._slope
|
|
|
|
|
class ExponentialDecaySchedule(Schedule):
    """Exponential decay function.

    See https://en.wikipedia.org/wiki/Exponential_decay.

    Suited to decaying a value across orders of magnitude, for example a
    learning rate going from 1e-2 to 1e-6. The exponent is decayed linearly:
    the schedule interpolates log(initial) -> log(final) with a
    LinearDecaySchedule and exponentiates the result.

    config:
      initial: Decay starts from this value.
      final: Decay ends at this value.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    f(t) is an exponential decay function when start_time <= t <= end_time.
    The decay rate and amplitude are chosen so that f(t) = initial when
    t = start_time, and f(t) = final when t = end_time. f(t) is constant for
    t < start_time or t > end_time. initial and final must be positive
    values.

    If start_time == end_time, this becomes a step function.
    """

    def __init__(self, config):
        """Validate the endpoints and build the log-space linear schedule.

        Args:
          config: Config object exposing `initial`, `final`, `start_time`
            and `end_time`.

        Raises:
          ValueError: If `initial` or `final` is not strictly positive. The
            inner LinearDecaySchedule may also raise ValueError when
            `end_time` is smaller than `start_time`.
        """
        super(ExponentialDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time

        if self.initial <= 0 or self.final <= 0:
            raise ValueError('initial and final must be positive numbers.')

        # Same time window as this schedule, but with the endpoints moved
        # into log space so that a linear ramp becomes an exponential one.
        log_space_config = config_lib.Config(
            initial=math.log(self.initial),
            final=math.log(self.final),
            start_time=self.start_time,
            end_time=self.end_time)
        self._linear_fn = LinearDecaySchedule(log_space_config)

    def __call__(self, global_step):
        """Return exp() of the log-space linear schedule at `global_step`.

        Args:
          global_step: Non-negative integer.

        Returns:
          The exponentially decayed value, a float.
        """
        log_value = self._linear_fn(global_step)
        return math.exp(log_value)
|
|
|
|
|
class SmootherstepDecaySchedule(Schedule):
    """Smootherstep decay function.

    A sigmoid-like transition from the initial to the final value. The
    transition is smoother than the linear and exponential decays, hence the
    name. See https://en.wikipedia.org/wiki/Smoothstep.

    config:
      initial: Decay starts from this value.
      final: Decay ends at this value.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    f(t) is fully defined here:
    https://en.wikipedia.org/wiki/Smoothstep#Variations.

    f(t) is smooth, as in its first-derivative exists everywhere.
    """

    def __init__(self, config):
        """Read the decay endpoints from `config` and cache their spans.

        Args:
          config: Config object exposing `initial`, `final`, `start_time`
            and `end_time`.

        Raises:
          ValueError: If `end_time` is smaller than `start_time`.
        """
        super(SmootherstepDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time

        if self.end_time < self.start_time:
            raise ValueError('start_time must be before end_time.')

        self._time_diff = float(self.end_time - self.start_time)
        self._diff = float(self.final - self.initial)

    def __call__(self, global_step):
        """Return the smootherstep-interpolated value at `global_step`.

        Args:
          global_step: Non-negative integer.

        Returns:
          `initial` for global_step <= start_time, `final` for
          global_step > end_time, otherwise `initial` plus the polynomial
          6x^5 - 15x^4 + 10x^3 (x being the fraction of the window elapsed)
          times (final - initial).
        """
        if global_step <= self.start_time:
            return self.initial
        if global_step > self.end_time:
            return self.final

        # Fraction of the decay window elapsed, in (0, 1]. A zero-length
        # window never reaches this point, so there is no division by zero.
        x = (global_step - self.start_time) / self._time_diff

        # 6x^5 - 15x^4 + 10x^3, with the inner factor in Horner form.
        blend = x * x * x * (x * (x * 6 - 15) + 10)
        return self.initial + blend * self._diff
|
|
|
|
|
class HardOscillatorSchedule(Schedule):
    """Hard oscillator function.

    config:
      high: Max value of the oscillator. Value at constant plateaus.
      low: Min value of the oscillator. Value at constant valleys.
      start_time: Global step when oscillation starts. Constant before this.
      period: Width of one oscillation, i.e. number of steps over which the
        oscillation takes place.
      transition_fraction: Fraction of the period spent transitioning
        between high and low values. Half of that time is spent falling and
        half rising. The remaining time is split evenly between sitting at
        the low value and sitting at the high value.
        transition_fraction = 1.0 means the entire period is spent rising
        and falling. transition_fraction = 0.0 means no time is spent rising
        and falling, i.e. the function jumps instantaneously between high
        and low.

    f(t) = high when t < start_time.
    f(t) is periodic when t >= start_time, with f(t + period) = f(t).

    Each period has two halves. In the first half f descends linearly from
    high to low over transition_fraction / 2 of the period and then stays at
    low. In the second half f ascends linearly from low to high over
    transition_fraction / 2 of the period and then stays at high until the
    period repeats.

    Note: when transition_fraction is 0, f starts the period low and ends
    high.
    """

    def __init__(self, config):
        """Read the oscillator parameters from `config` and validate them.

        Args:
          config: Config object exposing `high`, `low`, `start_time`,
            `period` and `transition_fraction`.

        Raises:
          ValueError: If `transition_fraction` is below 0 or above 1.0, or
            if `period` is not positive.
        """
        super(HardOscillatorSchedule, self).__init__(config)
        self.high = config.high
        self.low = config.low
        self.start_time = config.start_time
        self.period = float(config.period)
        self.transition_fraction = config.transition_fraction
        self.half_transition_fraction = config.transition_fraction / 2.0

        if self.transition_fraction < 0 or self.transition_fraction > 1.0:
            raise ValueError('transition_fraction must be between 0 and 1.0')
        if self.period <= 0:
            raise ValueError('period must be positive')

        # Slope is expressed per unit of period position (0..1), not per
        # step. With no transition time the ramps are never evaluated; inf
        # is only a placeholder that avoids a division by zero.
        if self.half_transition_fraction > 0:
            self._slope = (
                float(self.high - self.low) / self.half_transition_fraction)
        else:
            self._slope = float('inf')

    def __call__(self, global_step):
        """Return the oscillator value at `global_step`.

        Args:
          global_step: Non-negative integer.

        Returns:
          `high` before start_time; afterwards the value of the periodic
          fall / low / rise / high waveform at this step.
        """
        if global_step < self.start_time:
            return self.high

        # Position inside the current period, in [0, 1).
        phase = ((global_step - self.start_time) / self.period) % 1.0

        if phase < 0.5:
            # First half: ramp down from high, then hold low.
            if phase < self.half_transition_fraction:
                return self.high - phase * self._slope
            return self.low

        # Second half: ramp up from low, then hold high.
        phase -= 0.5
        if phase < self.half_transition_fraction:
            return self.low + phase * self._slope
        return self.high
|
|
|
|
|
# Registry used by `make_schedule`: maps the `fn` name given in a config to
# the Schedule class that implements it.
_NAME_TO_CONFIG = dict(
    const=ConstSchedule,
    linear_decay=LinearDecaySchedule,
    exp_decay=ExponentialDecaySchedule,
    smooth_decay=SmootherstepDecaySchedule,
    hard_osc=HardOscillatorSchedule,
)
|
|
|
|
|
def make_schedule(config):
    """Schedule factory.

    Looks up `config.fn` in `_NAME_TO_CONFIG` and instantiates the matching
    Schedule implementation, handing it the same `config`. See
    `_NAME_TO_CONFIG` for the available `fn` options.

    Args:
      config: Config with a `fn` option that specifies which Schedule
        implementation to use. `config` is passed into the constructor.

    Returns:
      A Schedule impl instance.

    Raises:
      KeyError: If `config.fn` is not a key of `_NAME_TO_CONFIG`.
    """
    return _NAME_TO_CONFIG[config.fn](config)
|
|