|
<html><head><title>dlib C++ Library - solvers_abstract.h</title></head><body bgcolor='white'><pre> |
|
<font color='#009900'>// Copyright (C) 2015 Davis E. King (davis@dlib.net) |
|
</font><font color='#009900'>// License: Boost Software License See LICENSE.txt for the full license. |
|
</font><font color='#0000FF'>#undef</font> DLIB_DNn_SOLVERS_ABSTRACT_H_ |
|
<font color='#0000FF'>#ifdef</font> DLIB_DNn_SOLVERS_ABSTRACT_H_ |
|
|
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../cuda/tensor_abstract.h.html'>../cuda/tensor_abstract.h</a>" |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font> |
|
|
|
<font color='#0000FF'>namespace</font> dlib |
|
<b>{</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font><font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font><font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>class</font> <b><a name='EXAMPLE_SOLVER'></a>EXAMPLE_SOLVER</b> |
|
<b>{</b> |
|
<font color='#009900'>/*! |
|
WHAT THIS OBJECT REPRESENTS |
|
A solver defines the parameter update rule for a single layer in a deep |
|
neural network. It takes a parameter gradient vector and the layer's |
|
parameters and tells you how the parameters should be updated. |
|
Importantly, each solver instance is used with only one layer in a network. |
|
This allows us to define solvers that have per-layer state. For example, a |
|
solver may keep a momentum term and apply it to its update rule. |
|
|
|
Note that there is no dlib::EXAMPLE_SOLVER type. It is shown here purely |
|
to document the interface a solver object must implement. |
|
!*/</font> |
|
|
|
<font color='#0000FF'>public</font>: |
|
|
|
<b><a name='EXAMPLE_SOLVER'></a>EXAMPLE_SOLVER</b><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> layer_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> learning_rate, |
|
<font color='#0000FF'>const</font> layer_type<font color='#5555FF'>&</font> l, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> params_grad |
|
<font face='Lucida Console'>)</font> |
|
<font color='#009900'>/*! |
|
requires |
|
- l.get_layer_params().size() != 0 |
|
- have_same_dimensions(l.get_layer_params(), params_grad) == true |
|
- When this function is invoked on a particular solver instance, it is |
|
always supplied with the same layer instance, l. That is, the solver is |
|
allowed to remember things from one invocation to another and to assume |
|
that it is being serially applied to optimize the same layer's |
|
parameters. |
|
ensures |
|
- Returns a step vector V that is intended to be used to update the |
|
parameters by adding V to l.get_layer_params(). |
|
- This function will use the given "learning rate" to compute V. How the |
|
learning rate is used is solver dependent. But in general the learning |
|
rate should be used to select the step size, i.e. to somehow determine |
|
the magnitude of V. |
|
!*/</font> |
|
<b>}</b>; |
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> EXAMPLE_SOLVER<font color='#5555FF'>&</font> item, std::ostream<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>EXAMPLE_SOLVER<font color='#5555FF'>&</font> item, std::istream<font color='#5555FF'>&</font> in<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
provides serialization support |
|
!*/</font> |
|
|
|
std::ostream<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'><</font><font color='#5555FF'><</font> <font face='Lucida Console'>(</font>std::ostream<font color='#5555FF'>&</font> out, <font color='#0000FF'>const</font> EXAMPLE_SOLVER<font color='#5555FF'>&</font> item<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
Prints the solver's name and parameters to out. |
|
!*/</font> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font><font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font><font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>class</font> <b><a name='sgd'></a>sgd</b> |
|
<b>{</b> |
|
<font color='#009900'>/*! |
|
WHAT THIS OBJECT REPRESENTS |
|
This object implements the EXAMPLE_SOLVER interface defined above. It is a |
|
basic stochastic gradient descent solver which uses momentum and weight |
|
decay. In particular, it computes the update vector V according to: |
|
V = momentum*V - weight_decay*learning_rate*l.get_layer_params() - learning_rate*params_grad; |
|
Here V is a momentum term that is remembered by the solver from one |
|
invocation of operator() to the next. |
|
|
|
|
|
Note that the actual learning rate and weight decay used by the solver are |
|
multiplied by the per layer multipliers. That is, the solver will call |
|
get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and |
|
multiply these values with the nominal learning rate and weight decay, |
|
respectively, to determine the values it will use during each step. The |
|
solver's operator() is also overloaded to allow additional learning rate |
|
multipliers to be applied to fc_ and con_ bias parameters. |
|
!*/</font> |
|
<font color='#0000FF'>public</font>: |
|
|
|
<b><a name='sgd'></a>sgd</b><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
ensures |
|
- #get_weight_decay() == 0.0005 |
|
- #get_momentum() == 0.9 |
|
!*/</font> |
|
|
|
<font color='#0000FF'>explicit</font> <b><a name='sgd'></a>sgd</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>float</u></font> weight_decay, |
|
<font color='#0000FF'><u>float</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font> |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
requires |
|
- weight_decay >= 0 |
|
- momentum >= 0 |
|
ensures |
|
- #get_weight_decay() == weight_decay |
|
- #get_momentum() == momentum |
|
!*/</font> |
|
|
|
<font color='#0000FF'><u>float</u></font> <b><a name='get_weight_decay'></a>get_weight_decay</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<font color='#0000FF'><u>float</u></font> <b><a name='get_momentum'></a>get_momentum</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<b>}</b>; |
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> sgd<font color='#5555FF'>&</font> item, std::ostream<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>sgd<font color='#5555FF'>&</font> item, std::istream<font color='#5555FF'>&</font> in<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
provides serialization support |
|
!*/</font> |
|
|
|
std::ostream<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'><</font><font color='#5555FF'><</font> <font face='Lucida Console'>(</font>std::ostream<font color='#5555FF'>&</font> out, <font color='#0000FF'>const</font> sgd<font color='#5555FF'>&</font> item<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
Prints the solver's name and parameters to out. |
|
!*/</font> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>class</font> <b><a name='adam'></a>adam</b> |
|
<b>{</b> |
|
<font color='#009900'>/*! |
|
WHAT THIS OBJECT REPRESENTS |
|
This object implements the EXAMPLE_SOLVER interface defined above. In |
|
particular, it implements the ADAM parameter update method described in the |
|
paper: |
|
Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic |
|
optimization." International Conference on Learning Representations. 2015. |
|
|
|
|
|
Note that the actual learning rate and weight decay used by the solver are |
|
multiplied by the per layer multipliers. That is, the solver will call |
|
get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and |
|
multiply these values with the nominal learning rate and weight decay, |
|
respectively, to determine the values it will use during each step. The |
|
solver's operator() is also overloaded to allow additional learning rate |
|
multipliers to be applied to fc_ and con_ bias parameters. |
|
!*/</font> |
|
|
|
<font color='#0000FF'>public</font>: |
|
|
|
<b><a name='adam'></a>adam</b><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
ensures |
|
- #get_weight_decay() == 0.0005 |
|
- #get_momentum1() == 0.9 |
|
- #get_momentum2() == 0.999 |
|
!*/</font> |
|
|
|
<b><a name='adam'></a>adam</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>float</u></font> weight_decay, |
|
<font color='#0000FF'><u>float</u></font> momentum1, |
|
<font color='#0000FF'><u>float</u></font> momentum2 |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
requires |
|
- weight_decay >= 0 |
|
- 0 <= momentum1 < 1 |
|
- 0 <= momentum2 < 1 |
|
ensures |
|
- #get_weight_decay() == weight_decay |
|
- #get_momentum1() == momentum1 |
|
- #get_momentum2() == momentum2 |
|
!*/</font> |
|
|
|
<font color='#0000FF'><u>float</u></font> <b><a name='get_weight_decay'></a>get_weight_decay</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<font color='#0000FF'><u>float</u></font> <b><a name='get_momentum1'></a>get_momentum1</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<font color='#0000FF'><u>float</u></font> <b><a name='get_momentum2'></a>get_momentum2</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<b>}</b>; |
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> adam<font color='#5555FF'>&</font> item, std::ostream<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>adam<font color='#5555FF'>&</font> item, std::istream<font color='#5555FF'>&</font> in<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
provides serialization support |
|
!*/</font> |
|
|
|
std::ostream<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'><</font><font color='#5555FF'><</font> <font face='Lucida Console'>(</font>std::ostream<font color='#5555FF'>&</font> out, <font color='#0000FF'>const</font> adam<font color='#5555FF'>&</font> item<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
Prints the solver's name and parameters to out. |
|
!*/</font> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_DNn_SOLVERS_ABSTRACT_H_ |
|
</font> |
|
|
|
</pre></body></html> |