File size: 13,681 Bytes
9375c9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
<html><!-- Created using the cpp_pretty_printer from the dlib C++ library. See http://dlib.net for updates. --><head><title>dlib C++ Library - solvers_abstract.h</title></head><body bgcolor='white'><pre>
<font color='#009900'>// Copyright (C) 2015 Davis E. King (davis@dlib.net)
</font><font color='#009900'>// License: Boost Software License See LICENSE.txt for the full license.
</font><font color='#0000FF'>#undef</font> DLIB_DNn_SOLVERS_ABSTRACT_H_
<font color='#0000FF'>#ifdef</font> DLIB_DNn_SOLVERS_ABSTRACT_H_
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../cuda/tensor_abstract.h.html'>../cuda/tensor_abstract.h</a>"
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font>
<font color='#0000FF'>namespace</font> dlib
<b>{</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font><font color='#009900'>// ----------------------------------------------------------------------------------------
</font><font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'>class</font> <b><a name='EXAMPLE_SOLVER'></a>EXAMPLE_SOLVER</b>
<b>{</b>
<font color='#009900'>/*!
WHAT THIS OBJECT REPRESENTS
A solver defines the parameter update rule for a single layer in a deep
neural network. Given the gradient of a layer's parameters, together with
the layer itself, it tells you how those parameters should be updated.
Importantly, each solver instance is used with only one layer in a network.
This allows us to define solvers that have per-layer state. For example, a
solver may keep a momentum term and apply it in its update rule.
A solver is default constructible. The serialize(), deserialize(), and
operator<< declarations that follow this class document the free functions
that go with it; the sgd and adam solvers below provide the same three.
Note that there is no dlib::EXAMPLE_SOLVER type. It is shown here purely
to document the interface a solver object must implement.
!*/</font>
<font color='#0000FF'>public</font>:
<b><a name='EXAMPLE_SOLVER'></a>EXAMPLE_SOLVER</b><font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> layer_type<font color='#5555FF'>></font>
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> learning_rate,
<font color='#0000FF'>const</font> layer_type<font color='#5555FF'>&</font> l,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> params_grad
<font face='Lucida Console'>)</font>
<font color='#009900'>/*!
requires
- l.get_layer_params().size() != 0
- have_same_dimensions(l.get_layer_params(), params_grad) == true.
That is, params_grad is the gradient for l.get_layer_params() and has
exactly the same dimensions as those parameters.
- When this function is invoked on a particular solver instance, it is
always supplied with the same layer instance, l. That is, the solver is
allowed to remember things from one invocation to another and to assume
that it is being serially applied to optimize the same layer's
parameters.
ensures
- Returns a step vector V that is intended to be used to update the
parameters by adding V to l.get_layer_params().
- This function will use the given "learning rate" to compute V. How the
learning rate is used is solver-dependent, but in general the learning
rate should be used to select the step size, i.e. to somehow determine
the magnitude of V.
NOTE(review): V is returned by const reference, so it presumably refers to
storage owned by the solver that may be overwritten by the next call to
operator(). Confirm this before holding on to the returned tensor.
!*/</font>
<b>}</b>;
<font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> EXAMPLE_SOLVER<font color='#5555FF'>&</font> item, std::ostream<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font>;
<font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>EXAMPLE_SOLVER<font color='#5555FF'>&</font> item, std::istream<font color='#5555FF'>&</font> in<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
provides serialization support: serialize() writes item to the output
stream out, and deserialize() restores item from the input stream in.
!*/</font>
std::ostream<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'><</font><font color='#5555FF'><</font> <font face='Lucida Console'>(</font>std::ostream<font color='#5555FF'>&</font> out, <font color='#0000FF'>const</font> EXAMPLE_SOLVER<font color='#5555FF'>&</font> item<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
Prints the solver's name and its parameters to out.
!*/</font>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font><font color='#009900'>// ----------------------------------------------------------------------------------------
</font><font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'>class</font> <b><a name='sgd'></a>sgd</b>
<b>{</b>
<font color='#009900'>/*!
WHAT THIS OBJECT REPRESENTS
This object implements the EXAMPLE_SOLVER interface defined above. It is a
basic stochastic gradient descent solver which uses momentum and weight
decay. In particular, it computes the update vector V according to:
V = momentum*V - weight_decay*learning_rate*l.get_layer_params() - learning_rate*params_grad;
Here V is a momentum term that is remembered by the solver from one
invocation of operator() to the next. The newly computed V is also the
step that operator() returns, i.e. the value added to l.get_layer_params().
Note that the actual learning rate and weight decay used by the solver are
the nominal values scaled by per-layer multipliers. That is, the solver
will call get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l)
and multiply these values with the nominal learning rate and weight decay,
respectively, to determine the values it will use during each step. It
also has overloads that allow additional learning rate multipliers to be
applied to the bias parameters of fc_ and con_ layers.
!*/</font>
<font color='#0000FF'>public</font>:
<b><a name='sgd'></a>sgd</b><font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
ensures
- #get_weight_decay() == 0.0005
- #get_momentum() == 0.9
!*/</font>
<font color='#0000FF'>explicit</font> <b><a name='sgd'></a>sgd</b><font face='Lucida Console'>(</font>
<font color='#0000FF'><u>float</u></font> weight_decay,
<font color='#0000FF'><u>float</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>
<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
requires
- weight_decay >= 0
- momentum >= 0
ensures
- #get_weight_decay() == weight_decay
- #get_momentum() == momentum
(momentum is the factor the previous V is scaled by, and weight_decay is
the coefficient on l.get_layer_params(), in the update rule shown above.)
!*/</font>
<font color='#0000FF'><u>float</u></font> <b><a name='get_weight_decay'></a>get_weight_decay</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
<font color='#0000FF'><u>float</u></font> <b><a name='get_momentum'></a>get_momentum</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
<b>}</b>;
<font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> sgd<font color='#5555FF'>&</font> item, std::ostream<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font>;
<font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>sgd<font color='#5555FF'>&</font> item, std::istream<font color='#5555FF'>&</font> in<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
provides serialization support: serialize() writes item to the output
stream out, and deserialize() restores item from the input stream in.
!*/</font>
std::ostream<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'><</font><font color='#5555FF'><</font> <font face='Lucida Console'>(</font>std::ostream<font color='#5555FF'>&</font> out, <font color='#0000FF'>const</font> sgd<font color='#5555FF'>&</font> item<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
Prints the solver's name and its parameters to out.
!*/</font>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'>class</font> <b><a name='adam'></a>adam</b>
<b>{</b>
<font color='#009900'>/*!
WHAT THIS OBJECT REPRESENTS
This object implements the EXAMPLE_SOLVER interface defined above. In
particular, it implements the ADAM parameter update method described in the
paper:
Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic
optimization." International Conference on Learning Representations. 2015.
Note that the actual learning rate and weight decay used by the solver are
the nominal values scaled by per-layer multipliers. That is, the solver
will call get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l)
and multiply these values with the nominal learning rate and weight decay,
respectively, to determine the values it will use during each step. It
also has overloads that allow additional learning rate multipliers to be
applied to the bias parameters of fc_ and con_ layers.
!*/</font>
<font color='#0000FF'>public</font>:
<b><a name='adam'></a>adam</b><font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
ensures
- #get_weight_decay() == 0.0005
- #get_momentum1() == 0.9
- #get_momentum2() == 0.999
!*/</font>
<b><a name='adam'></a>adam</b><font face='Lucida Console'>(</font>
<font color='#0000FF'><u>float</u></font> weight_decay,
<font color='#0000FF'><u>float</u></font> momentum1,
<font color='#0000FF'><u>float</u></font> momentum2
<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
requires
- weight_decay >= 0
- 0 <= momentum1 < 1
- 0 <= momentum2 < 1
ensures
- #get_weight_decay() == weight_decay
- #get_momentum1() == momentum1
- #get_momentum2() == momentum2
NOTE(review): momentum1 and momentum2 presumably play the roles of the
paper's beta1 and beta2 (the decay rates of the first and second moment
estimates); confirm against the implementation.
!*/</font>
<font color='#0000FF'><u>float</u></font> <b><a name='get_weight_decay'></a>get_weight_decay</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
<font color='#0000FF'><u>float</u></font> <b><a name='get_momentum1'></a>get_momentum1</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
<font color='#0000FF'><u>float</u></font> <b><a name='get_momentum2'></a>get_momentum2</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
<b>}</b>;
<font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> adam<font color='#5555FF'>&</font> item, std::ostream<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font>;
<font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>adam<font color='#5555FF'>&</font> item, std::istream<font color='#5555FF'>&</font> in<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
provides serialization support: serialize() writes item to the output
stream out, and deserialize() restores item from the input stream in.
!*/</font>
std::ostream<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'><</font><font color='#5555FF'><</font> <font face='Lucida Console'>(</font>std::ostream<font color='#5555FF'>&</font> out, <font color='#0000FF'>const</font> adam<font color='#5555FF'>&</font> item<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
Prints the solver's name and its parameters to out.
!*/</font>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<b>}</b>
<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_DNn_SOLVERS_ABSTRACT_H_
</font>
</pre></body></html> |