File size: 13,681 Bytes
9375c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
<html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - solvers_abstract.h</title></head><body bgcolor='white'><pre>
<font color='#009900'>// Copyright (C) 2015  Davis E. King (davis@dlib.net)
</font><font color='#009900'>// License: Boost Software License   See LICENSE.txt for the full license.
</font><font color='#0000FF'>#undef</font> DLIB_DNn_SOLVERS_ABSTRACT_H_
<font color='#0000FF'>#ifdef</font> DLIB_DNn_SOLVERS_ABSTRACT_H_

<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../cuda/tensor_abstract.h.html'>../cuda/tensor_abstract.h</a>"
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iostream<font color='#5555FF'>&gt;</font>

<font color='#0000FF'>namespace</font> dlib
<b>{</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font><font color='#009900'>// ----------------------------------------------------------------------------------------
</font><font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
    <font color='#0000FF'>class</font> <b><a name='EXAMPLE_SOLVER'></a>EXAMPLE_SOLVER</b> 
    <b>{</b>
        <font color='#009900'>/*!
            WHAT THIS OBJECT REPRESENTS
                A solver defines the parameter update rule for a single layer in a deep
                neural network.  It takes a parameter gradient vector and the layer's
                parameters and tells you how the parameters should be updated.
                Importantly, each solver instance is used with only one layer in a network.
                This allows us to define solvers that have per layer state, for example, a
                solver may keep a momentum term and apply it to its update rule.

                Note that there is no dlib::EXAMPLE_SOLVER type.  It is shown here purely
                to document the interface a solver object must implement.
        !*/</font>

    <font color='#0000FF'>public</font>:

        <b><a name='EXAMPLE_SOLVER'></a>EXAMPLE_SOLVER</b><font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font>;

        <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> layer_type<font color='#5555FF'>&gt;</font>
        <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> learning_rate,
            <font color='#0000FF'>const</font> layer_type<font color='#5555FF'>&amp;</font> l,
            <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> params_grad
        <font face='Lucida Console'>)</font>
        <font color='#009900'>/*!
            requires
                - l.get_layer_params().size() != 0
                - have_same_dimensions(l.get_layer_params(), params_grad) == true
                - When this function is invoked on a particular solver instance, it is
                  always supplied with the same layer instance, l.  That is, the solver is
                  allowed to remember things from one invocation to another and to assume
                  that it is being serially applied to optimize the same layer's
                  parameters. 
            ensures
                - Returns a step vector V that is intended to be used to update the
                  parameters by adding V to l.get_layer_params().
                - This function will use the given "learning rate" to compute V.  How the
                  learning rate is used is solver dependent.  But in general the learning
                  rate should be used to select the step size, i.e. to somehow determine
                  the magnitude of V.
        !*/</font>
    <b>}</b>;

    <font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> EXAMPLE_SOLVER<font color='#5555FF'>&amp;</font> item, std::ostream<font color='#5555FF'>&amp;</font> out<font face='Lucida Console'>)</font>;
    <font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>EXAMPLE_SOLVER<font color='#5555FF'>&amp;</font> item, std::istream<font color='#5555FF'>&amp;</font> in<font face='Lucida Console'>)</font>;
    <font color='#009900'>/*!
        provides serialization support  
    !*/</font>

    std::ostream<font color='#5555FF'>&amp;</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font face='Lucida Console'>(</font>std::ostream<font color='#5555FF'>&amp;</font> out, <font color='#0000FF'>const</font> EXAMPLE_SOLVER<font color='#5555FF'>&amp;</font> item<font face='Lucida Console'>)</font>;
    <font color='#009900'>/*!
        Prints the solver's name and parameters to out.
    !*/</font>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font><font color='#009900'>// ----------------------------------------------------------------------------------------
</font><font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
    <font color='#0000FF'>class</font> <b><a name='sgd'></a>sgd</b>
    <b>{</b>
        <font color='#009900'>/*!
            WHAT THIS OBJECT REPRESENTS
                This object implements the EXAMPLE_SOLVER interface defined above.  It is a
                basic stochastic gradient descent solver which uses momentum and weight
                decay.  In particular, it computes the update vector V according to:
                    V = momentum*V - weight_decay*learning_rate*l.get_layer_params() - learning_rate*params_grad;
                Here V is a momentum term that is remembered by the solver from one
                invocation of operator() to the next.  


                Note that the actual learning rate and weight decay used by the solver are
                multiplied by the per layer multipliers.  That is, the solver will call
                get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and
                multiply these values with the nominal learning rate and weight decay,
                respectively, to determine the values it will use during each step.  The
                solver's operator() is also overloaded to allow additional learning rate
                multipliers to be applied to fc_ and con_ bias parameters.
        !*/</font>
    <font color='#0000FF'>public</font>:

        <b><a name='sgd'></a>sgd</b><font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font>; 
        <font color='#009900'>/*!
            ensures
                - #get_weight_decay()  == 0.0005 
                - #get_momentum()      == 0.9 
        !*/</font>

        <font color='#0000FF'>explicit</font> <b><a name='sgd'></a>sgd</b><font face='Lucida Console'>(</font>
            <font color='#0000FF'><u>float</u></font> weight_decay,
            <font color='#0000FF'><u>float</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>
        <font face='Lucida Console'>)</font>; 
        <font color='#009900'>/*!
            requires
                - weight_decay &gt;= 0
                - momentum &gt;= 0
            ensures
                - #get_weight_decay()  == weight_decay 
                - #get_momentum()      == momentum 
        !*/</font>

        <font color='#0000FF'><u>float</u></font> <b><a name='get_weight_decay'></a>get_weight_decay</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
        <font color='#0000FF'><u>float</u></font> <b><a name='get_momentum'></a>get_momentum</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; 
    <b>}</b>;

    <font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> sgd<font color='#5555FF'>&amp;</font> item, std::ostream<font color='#5555FF'>&amp;</font> out<font face='Lucida Console'>)</font>;
    <font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>sgd<font color='#5555FF'>&amp;</font> item, std::istream<font color='#5555FF'>&amp;</font> in<font face='Lucida Console'>)</font>;
    <font color='#009900'>/*!
        provides serialization support  
    !*/</font>

    std::ostream<font color='#5555FF'>&amp;</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font face='Lucida Console'>(</font>std::ostream<font color='#5555FF'>&amp;</font> out, <font color='#0000FF'>const</font> sgd<font color='#5555FF'>&amp;</font> item<font face='Lucida Console'>)</font>;
    <font color='#009900'>/*!
        Prints the solver's name and parameters to out.
    !*/</font>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
    <font color='#0000FF'>class</font> <b><a name='adam'></a>adam</b>
    <b>{</b>
        <font color='#009900'>/*!
            WHAT THIS OBJECT REPRESENTS
                This object implements the EXAMPLE_SOLVER interface defined above.  In
                particular, it implements the ADAM parameter update method described in the
                paper:
                    Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic
                    optimization." International Conference on Learning Representations. 2015.


                Note that the actual learning rate and weight decay used by the solver are
                multiplied by the per layer multipliers.  That is, the solver will call
                get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and
                multiply these values with the nominal learning rate and weight decay,
                respectively, to determine the values it will use during each step.  The
                solver's operator() is also overloaded to allow additional learning rate
                multipliers to be applied to fc_ and con_ bias parameters.
        !*/</font>

    <font color='#0000FF'>public</font>:

        <b><a name='adam'></a>adam</b><font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font>; 
        <font color='#009900'>/*!
            ensures
                - #get_weight_decay()  == 0.0005 
                - #get_momentum1()     == 0.9 
                - #get_momentum2()     == 0.999 
        !*/</font>

        <b><a name='adam'></a>adam</b><font face='Lucida Console'>(</font>
            <font color='#0000FF'><u>float</u></font> weight_decay,
            <font color='#0000FF'><u>float</u></font> momentum1, 
            <font color='#0000FF'><u>float</u></font> momentum2 
        <font face='Lucida Console'>)</font>; 
        <font color='#009900'>/*!
            requires
                - weight_decay &gt;= 0
                - 0 &lt;= momentum1 &lt; 1
                - 0 &lt;= momentum2 &lt; 1
            ensures
                - #get_weight_decay()  == weight_decay 
                - #get_momentum1()     == momentum1
                - #get_momentum2()     == momentum2
        !*/</font>

        <font color='#0000FF'><u>float</u></font> <b><a name='get_weight_decay'></a>get_weight_decay</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
        <font color='#0000FF'><u>float</u></font> <b><a name='get_momentum1'></a>get_momentum1</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; 
        <font color='#0000FF'><u>float</u></font> <b><a name='get_momentum2'></a>get_momentum2</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; 
    <b>}</b>;

    <font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> adam<font color='#5555FF'>&amp;</font> item, std::ostream<font color='#5555FF'>&amp;</font> out<font face='Lucida Console'>)</font>;
    <font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>adam<font color='#5555FF'>&amp;</font> item, std::istream<font color='#5555FF'>&amp;</font> in<font face='Lucida Console'>)</font>;
    <font color='#009900'>/*!
        provides serialization support  
    !*/</font>

    std::ostream<font color='#5555FF'>&amp;</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font face='Lucida Console'>(</font>std::ostream<font color='#5555FF'>&amp;</font> out, <font color='#0000FF'>const</font> adam<font color='#5555FF'>&amp;</font> item<font face='Lucida Console'>)</font>;
    <font color='#009900'>/*!
        Prints the solver's name and parameters to out.
    !*/</font>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<b>}</b>

<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_DNn_SOLVERS_ABSTRACT_H_
</font>

</pre></body></html>