|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
"""Tests for common.schedules.""" |
|
|
|
from math import exp |
|
from math import sqrt |
|
import numpy as np |
|
from six.moves import xrange |
|
import tensorflow as tf |
|
|
|
from common import config_lib |
|
from common import schedules |
|
|
|
|
|
class SchedulesTest(tf.test.TestCase):
  """Checks the outputs of the schedule objects built by make_schedule."""

  # Failure-message templates, each formatted with (input, expected, actual).
  _WRONG_VALUE_MSG = 'Wrong value at input %d. Expected %s, got %s'
  _DUPLICATE_CALL_MSG = (
      'Duplicate calls at input %d are not equal. Expected %s, got %s')

  def _AssertScheduleValue(self, schedule, step, expected, message):
    """Asserts that `schedule(step)` is numerically close to `expected`.

    Args:
      schedule: Callable returned by schedules.make_schedule.
      step: Input at which the schedule is evaluated.
      expected: Expected output at `step`.
      message: %-format template taking (input, expected, actual); used as the
          failure message.
    """
    actual = schedule(step)
    self.assertTrue(
        np.isclose(expected, actual),
        message % (step, expected, actual))

  def ScheduleTestHelper(self, config, schedule_subtype, io_values):
    """Run common checks for schedules.

    Four checks are performed, each on freshly built schedule instances:
      1. make_schedule returns an instance of `schedule_subtype`.
      2. Several independently built instances all map every input to the
         expected output.
      3. A single instance evaluated on a sparse subset of the inputs still
         produces the expected outputs.
      4. Calling a single instance repeatedly at the same input keeps
         producing the expected output.

    Args:
      config: Config object which is passed into schedules.make_schedule.
      schedule_subtype: The expected schedule type to be instantiated.
      io_values: List of (input, output) pairs. Must be in ascending input
          order. No duplicate inputs. An empty list leaves only the type
          check in effect.
    """
    # make_schedule must build an instance of the expected class.
    # assertIsInstance reports the actual type when the check fails.
    self.assertIsInstance(schedules.make_schedule(config), schedule_subtype)

    # Independently constructed instances must agree on every pair.
    fns = [schedules.make_schedule(config) for _ in xrange(3)]
    for i, o in io_values:
      for f in fns:
        self._AssertScheduleValue(f, i, o, self._WRONG_VALUE_MSG)

    # Evaluate a fresh instance on a sparse subset only (indices 0, 1, 4, 9,
    # ... plus the final pair), so intermediate inputs are skipped.
    # NOTE(review): presumably this guards against schedules whose result
    # depends on having been called at every preceding input -- confirm.
    f = schedules.make_schedule(config)
    subseq = [io_values[i**2] for i in xrange(int(sqrt(len(io_values))))]
    # `subseq` is empty only when `io_values` is empty; guard the [-1] lookups
    # so an empty list does not raise IndexError.
    if subseq and subseq[-1] != io_values[-1]:
      subseq.append(io_values[-1])
    for i, o in subseq:
      self._AssertScheduleValue(f, i, o, self._WRONG_VALUE_MSG)

    # Repeated calls on one instance at the same input must keep returning
    # the expected value.
    f = schedules.make_schedule(config)
    for i, o in io_values:
      for _ in xrange(3):
        self._AssertScheduleValue(f, i, o, self._DUPLICATE_CALL_MSG)

  def testConstSchedule(self):
    """The 'const' schedule is expected to return `const` at every input."""
    self.ScheduleTestHelper(
        config_lib.Config(fn='const', const=5),
        schedules.ConstSchedule,
        [(0, 5), (1, 5), (10, 5), (20, 5), (100, 5), (1000000, 5)])

  def testLinearDecaySchedule(self):
    """Checks 'linear_decay' over a normal and a zero-length decay window."""
    # Expected: `initial` up to start_time, a straight line down to `final`
    # at end_time, and `final` afterwards.
    self.ScheduleTestHelper(
        config_lib.Config(fn='linear_decay', initial=2, final=0, start_time=10,
                          end_time=20),
        schedules.LinearDecaySchedule,
        [(0, 2), (1, 2), (10, 2), (11, 1.8), (15, 1), (19, 0.2), (20, 0),
         (100000, 0)])

    # start_time == end_time: expected to step from `initial` at input 10
    # straight to `final` at input 11.
    self.ScheduleTestHelper(
        config_lib.Config(fn='linear_decay', initial=2, final=0, start_time=10,
                          end_time=10),
        schedules.LinearDecaySchedule,
        [(0, 2), (1, 2), (10, 2), (11, 0), (15, 0)])

  def testExponentialDecaySchedule(self):
    """Checks 'exp_decay' over a normal and a zero-length decay window."""
    # Expected values interpolate linearly in the exponent: from -1 at
    # start_time to -6 at end_time, i.e. -0.5 per step.
    self.ScheduleTestHelper(
        config_lib.Config(fn='exp_decay', initial=exp(-1), final=exp(-6),
                          start_time=10, end_time=20),
        schedules.ExponentialDecaySchedule,
        [(0, exp(-1)), (1, exp(-1)), (10, exp(-1)), (11, exp(-1/2. - 1)),
         (15, exp(-5/2. - 1)), (19, exp(-9/2. - 1)), (20, exp(-6)),
         (100000, exp(-6))])

    # start_time == end_time: expected to step from `initial` at input 10
    # straight to `final` at input 11.
    self.ScheduleTestHelper(
        config_lib.Config(fn='exp_decay', initial=exp(-1), final=exp(-6),
                          start_time=10, end_time=10),
        schedules.ExponentialDecaySchedule,
        [(0, exp(-1)), (1, exp(-1)), (10, exp(-1)), (11, exp(-6)),
         (15, exp(-6))])

  def testSmootherstepDecaySchedule(self):
    """Checks 'smooth_decay' over a normal and a zero-length decay window."""
    # Expected values trace an S-shaped curve between start_time and
    # end_time: nearly flat next to both ends and exactly halfway (1) at the
    # midpoint, input 15.
    self.ScheduleTestHelper(
        config_lib.Config(fn='smooth_decay', initial=2, final=0, start_time=10,
                          end_time=20),
        schedules.SmootherstepDecaySchedule,
        [(0, 2), (1, 2), (10, 2), (11, 1.98288), (15, 1), (19, 0.01712),
         (20, 0), (100000, 0)])

    # start_time == end_time: expected to step from `initial` at input 10
    # straight to `final` at input 11.
    self.ScheduleTestHelper(
        config_lib.Config(fn='smooth_decay', initial=2, final=0, start_time=10,
                          end_time=10),
        schedules.SmootherstepDecaySchedule,
        [(0, 2), (1, 2), (10, 2), (11, 0), (15, 0)])

  def testHardOscillatorSchedule(self):
    """Checks 'hard_osc' with ramped transitions and with instant ones."""
    # transition_fraction=0.5: expected to stay at `high` until start_time,
    # then repeat every `period` inputs a linear ramp down to `low`, a flat
    # stretch at `low`, and a linear ramp back up to `high`. The inputs above
    # 100000 check that the pattern still holds far from start_time.
    self.ScheduleTestHelper(
        config_lib.Config(fn='hard_osc', high=2, low=0, start_time=100,
                          period=10, transition_fraction=0.5),
        schedules.HardOscillatorSchedule,
        [(0, 2), (1, 2), (10, 2), (100, 2), (101, 1.2), (102, 0.4), (103, 0),
         (104, 0), (105, 0), (106, 0.8), (107, 1.6), (108, 2), (109, 2),
         (110, 2), (111, 1.2), (112, 0.4), (115, 0), (116, 0.8), (119, 2),
         (120, 2), (100001, 1.2), (100002, 0.4), (100005, 0), (100006, 0.8),
         (100010, 2)])

    # transition_fraction=0: expected to be a square wave that drops to `low`
    # at start_time and switches level every half period.
    self.ScheduleTestHelper(
        config_lib.Config(fn='hard_osc', high=2, low=0, start_time=100,
                          period=10, transition_fraction=0),
        schedules.HardOscillatorSchedule,
        [(0, 2), (1, 2), (10, 2), (99, 2), (100, 0), (104, 0), (105, 2),
         (106, 2), (109, 2), (110, 0)])
|
|
|
|
|
# Script entry point: hand control to TensorFlow's test runner only when this
# file is executed directly, not when it is imported.
if __name__ == '__main__':
  tf.test.main()
|
|