"""Optimizers.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import functools |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
class OptimizerFactory(object):
  """Class to generate optimizer function."""

  def __init__(self, params):
    """Creates an optimizer based on the specified flags."""
    if params.type == 'momentum':
      # Bind the momentum hyperparameters now; the learning rate is
      # supplied later in __call__.
      self._optimizer = functools.partial(
          tf.keras.optimizers.SGD,
          momentum=params.momentum,
          nesterov=params.nesterov)
    elif params.type == 'adam':
      self._optimizer = tf.keras.optimizers.Adam
    elif params.type == 'adadelta':
      self._optimizer = tf.keras.optimizers.Adadelta
    elif params.type == 'adagrad':
      self._optimizer = tf.keras.optimizers.Adagrad
    elif params.type == 'rmsprop':
      self._optimizer = functools.partial(
          tf.keras.optimizers.RMSprop, momentum=params.momentum)
    else:
      raise ValueError('Unsupported optimizer type `{}`.'.format(params.type))

  def __call__(self, learning_rate):
    """Instantiates the configured optimizer with the given learning rate."""
    return self._optimizer(learning_rate=learning_rate)
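

# A minimal usage sketch, not part of the original module: `params` can be
# any object exposing `type`, `momentum`, and `nesterov` attributes. The
# `types.SimpleNamespace` below is an assumed stand-in for the project's
# real config object.
if __name__ == '__main__':
  import types

  params = types.SimpleNamespace(type='momentum', momentum=0.9, nesterov=True)
  optimizer_factory = OptimizerFactory(params)
  optimizer = optimizer_factory(learning_rate=0.01)
  print(type(optimizer).__name__)  # SGD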
|
|