# NCTCMumbai's picture
# Upload 2583 files
# 97b6013 verified
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for common.schedules."""
from math import exp
from math import sqrt
import numpy as np
from six.moves import xrange
import tensorflow as tf
from common import config_lib # brain coder
from common import schedules # brain coder
class SchedulesTest(tf.test.TestCase):
  """Tests for the schedule objects built by `schedules.make_schedule`."""

  def _AssertScheduleOutput(self, schedule, step, expected, msg_template):
    """Assert that `schedule(step)` is numerically close to `expected`.

    The comparison uses `np.isclose` with its default tolerances.

    Args:
      schedule: Callable returned by schedules.make_schedule.
      step: Input value passed to the schedule.
      expected: Expected output of the schedule at `step`.
      msg_template: Failure message containing three %-placeholders, filled
          with (step, expected, actual output) in that order.
    """
    actual = schedule(step)
    self.assertTrue(
        np.isclose(expected, actual),
        msg_template % (step, expected, actual))

  def ScheduleTestHelper(self, config, schedule_subtype, io_values):
    """Run common checks for schedules.

    The following checks are performed, each on freshly made schedules:
      1. make_schedule returns an instance of `schedule_subtype`.
      2. Several independently made instances all map every input to its
         expected output.
      3. A single instance still gives the expected outputs when only a sparse
         subsequence of the inputs is queried (so some inputs are skipped).
      4. A single instance gives the expected output when the same input is
         queried several times in a row.

    Args:
      config: Config object which is passed into schedules.make_schedule.
      schedule_subtype: The expected schedule type to be instantiated.
      io_values: List of (input, output) pairs. Must be in ascending input
          order. No duplicate inputs. If empty, only the type check runs.
    """
    wrong_value_msg = 'Wrong value at input %d. Expected %s, got %s'
    duplicate_msg = ('Duplicate calls at input %d are not equal. '
                     'Expected %s, got %s')

    # Check that make_schedule makes the correct type.
    self.assertIsInstance(schedules.make_schedule(config), schedule_subtype)

    # Check that multiple instances returned from make_schedule behave the
    # same, and that all the inputs map to the right outputs.
    instances = [schedules.make_schedule(config) for _ in xrange(3)]
    for step, expected in io_values:
      for schedule in instances:
        self._AssertScheduleOutput(schedule, step, expected, wrong_value_msg)

    # Check that a subset of the io_values are still correct. The subset takes
    # the entries at indices 0, 1, 4, 9, ... and always ends with the last
    # entry, so the gaps between queried inputs keep growing.
    schedule = schedules.make_schedule(config)
    subseq = [io_values[k**2] for k in xrange(int(sqrt(len(io_values))))]
    # Guard against empty io_values, where there is no last element to add.
    if io_values and subseq[-1] != io_values[-1]:
      subseq.append(io_values[-1])
    for step, expected in subseq:
      self._AssertScheduleOutput(schedule, step, expected, wrong_value_msg)

    # Check duplicate calls.
    schedule = schedules.make_schedule(config)
    for step, expected in io_values:
      for _ in xrange(3):
        self._AssertScheduleOutput(schedule, step, expected, duplicate_msg)

  def testConstSchedule(self):
    """The 'const' schedule returns the configured constant at every input."""
    self.ScheduleTestHelper(
        config_lib.Config(fn='const', const=5),
        schedules.ConstSchedule,
        [(0, 5), (1, 5), (10, 5), (20, 5), (100, 5), (1000000, 5)])

  def testLinearDecaySchedule(self):
    """Checks 'linear_decay' with a 10-step ramp and with a zero-length ramp."""
    self.ScheduleTestHelper(
        config_lib.Config(fn='linear_decay', initial=2, final=0, start_time=10,
                          end_time=20),
        schedules.LinearDecaySchedule,
        [(0, 2), (1, 2), (10, 2), (11, 1.8), (15, 1), (19, 0.2), (20, 0),
         (100000, 0)])

    # Test step function.
    self.ScheduleTestHelper(
        config_lib.Config(fn='linear_decay', initial=2, final=0, start_time=10,
                          end_time=10),
        schedules.LinearDecaySchedule,
        [(0, 2), (1, 2), (10, 2), (11, 0), (15, 0)])

  def testExponentialDecaySchedule(self):
    """Checks 'exp_decay' with a 10-step ramp and with a zero-length ramp."""
    self.ScheduleTestHelper(
        config_lib.Config(fn='exp_decay', initial=exp(-1), final=exp(-6),
                          start_time=10, end_time=20),
        schedules.ExponentialDecaySchedule,
        [(0, exp(-1)), (1, exp(-1)), (10, exp(-1)), (11, exp(-1/2. - 1)),
         (15, exp(-5/2. - 1)), (19, exp(-9/2. - 1)), (20, exp(-6)),
         (100000, exp(-6))])

    # Test step function.
    self.ScheduleTestHelper(
        config_lib.Config(fn='exp_decay', initial=exp(-1), final=exp(-6),
                          start_time=10, end_time=10),
        schedules.ExponentialDecaySchedule,
        [(0, exp(-1)), (1, exp(-1)), (10, exp(-1)), (11, exp(-6)),
         (15, exp(-6))])

  def testSmootherstepDecaySchedule(self):
    """Checks 'smooth_decay' with a 10-step ramp and with a zero-length ramp."""
    self.ScheduleTestHelper(
        config_lib.Config(fn='smooth_decay', initial=2, final=0, start_time=10,
                          end_time=20),
        schedules.SmootherstepDecaySchedule,
        [(0, 2), (1, 2), (10, 2), (11, 1.98288), (15, 1), (19, 0.01712),
         (20, 0), (100000, 0)])

    # Test step function.
    self.ScheduleTestHelper(
        config_lib.Config(fn='smooth_decay', initial=2, final=0, start_time=10,
                          end_time=10),
        schedules.SmootherstepDecaySchedule,
        [(0, 2), (1, 2), (10, 2), (11, 0), (15, 0)])

  def testHardOscillatorSchedule(self):
    """Checks 'hard_osc' with transition_fraction 0.5 and with 0."""
    self.ScheduleTestHelper(
        config_lib.Config(fn='hard_osc', high=2, low=0, start_time=100,
                          period=10, transition_fraction=0.5),
        schedules.HardOscillatorSchedule,
        [(0, 2), (1, 2), (10, 2), (100, 2), (101, 1.2), (102, 0.4), (103, 0),
         (104, 0), (105, 0), (106, 0.8), (107, 1.6), (108, 2), (109, 2),
         (110, 2), (111, 1.2), (112, 0.4), (115, 0), (116, 0.8), (119, 2),
         (120, 2), (100001, 1.2), (100002, 0.4), (100005, 0), (100006, 0.8),
         (100010, 2)])

    # Test instantaneous step.
    self.ScheduleTestHelper(
        config_lib.Config(fn='hard_osc', high=2, low=0, start_time=100,
                          period=10, transition_fraction=0),
        schedules.HardOscillatorSchedule,
        [(0, 2), (1, 2), (10, 2), (99, 2), (100, 0), (104, 0), (105, 2),
         (106, 2), (109, 2), (110, 0)])
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()