File size: 10,109 Bytes
9375c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
<html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - lspi_abstract.h</title></head><body bgcolor='white'><pre>
<font color='#009900'>// Copyright (C) 2015  Davis E. King (davis@dlib.net)
</font><font color='#009900'>// License: Boost Software License   See LICENSE.txt for the full license.
</font><font color='#0000FF'>#undef</font> DLIB_LSPI_ABSTRACT_Hh_
<font color='#0000FF'>#ifdef</font> DLIB_LSPI_ABSTRACT_Hh_

<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='approximate_linear_models_abstract.h.html'>approximate_linear_models_abstract.h</a>"

<font color='#0000FF'>namespace</font> dlib
<b>{</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
    <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
        <font color='#0000FF'>typename</font> feature_extractor
        <font color='#5555FF'>&gt;</font>
    <font color='#0000FF'>class</font> <b><a name='lspi'></a>lspi</b>
    <b>{</b>
        <font color='#009900'>/*!
            REQUIREMENTS ON feature_extractor
                feature_extractor should implement the example_feature_extractor interface
                defined at the top of dlib/control/approximate_linear_models_abstract.h

            WHAT THIS OBJECT REPRESENTS
                This object is an implementation of the reinforcement learning algorithm
                described in the following paper:
                    Lagoudakis, Michail G., and Ronald Parr. "Least-squares policy
                    iteration." The Journal of Machine Learning Research 4 (2003):
                    1107-1149.
                
                This means that it takes a bunch of training data in the form of
                process_samples and outputs a policy that hopefully performs well when run
                on the process that generated those samples.
        !*/</font>

    <font color='#0000FF'>public</font>:
        <font color='#0000FF'>typedef</font> feature_extractor feature_extractor_type;
        <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> feature_extractor::state_type state_type;
        <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> feature_extractor::action_type action_type;

        <font color='#0000FF'>explicit</font> <b><a name='lspi'></a>lspi</b><font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> feature_extractor<font color='#5555FF'>&amp;</font> fe_
        <font face='Lucida Console'>)</font>; 
        <font color='#009900'>/*!
            ensures
                - #get_feature_extractor() == fe_
                - #get_lambda() == 0.01
                - #get_discount() == 0.8
                - #get_epsilon() == 0.01
                - is not verbose
                - #get_max_iterations() == 100
        !*/</font>

        <b><a name='lspi'></a>lspi</b><font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font>;
        <font color='#009900'>/*!
            ensures
                - #get_feature_extractor() == feature_extractor() 
                  (i.e. it will have its default value)
                - #get_lambda() == 0.01
                - #get_discount() == 0.8
                - #get_epsilon() == 0.01
                - is not verbose
                - #get_max_iterations() == 100
        !*/</font>

        <font color='#0000FF'><u>double</u></font> <b><a name='get_discount'></a>get_discount</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
        <font color='#009900'>/*!
            ensures
                - returns the discount applied to the sum of rewards in the Bellman
                  equation.
        !*/</font>

        <font color='#0000FF'><u>void</u></font> <b><a name='set_discount'></a>set_discount</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'><u>double</u></font> value
        <font face='Lucida Console'>)</font>;
        <font color='#009900'>/*!
            requires
                - 0 &lt; value &lt;= 1
            ensures
                - #get_discount() == value
        !*/</font>

        <font color='#0000FF'>const</font> feature_extractor<font color='#5555FF'>&amp;</font> <b><a name='get_feature_extractor'></a>get_feature_extractor</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
        <font color='#009900'>/*!
            ensures
                - returns the feature extractor used by this object
        !*/</font>

        <font color='#0000FF'><u>void</u></font> <b><a name='be_verbose'></a>be_verbose</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font>;
        <font color='#009900'>/*!
            ensures
                - This object will print status messages to standard out so that a 
                  user can observe the progress of the algorithm.
        !*/</font>

        <font color='#0000FF'><u>void</u></font> <b><a name='be_quiet'></a>be_quiet</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font>;
        <font color='#009900'>/*!
            ensures
                - this object will not print anything to standard out
        !*/</font>

        <font color='#0000FF'><u>void</u></font> <b><a name='set_epsilon'></a>set_epsilon</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'><u>double</u></font> eps
        <font face='Lucida Console'>)</font>;
        <font color='#009900'>/*!
            requires
                - eps &gt; 0
            ensures
                - #get_epsilon() == eps
        !*/</font>

        <font color='#0000FF'><u>double</u></font> <b><a name='get_epsilon'></a>get_epsilon</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
        <font color='#009900'>/*!
            ensures
                - returns the error epsilon that determines when training should stop.
                  Smaller values may result in a more accurate solution but take longer to
                  train.  
        !*/</font>

        <font color='#0000FF'><u>void</u></font> <b><a name='set_lambda'></a>set_lambda</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'><u>double</u></font> lambda_ 
        <font face='Lucida Console'>)</font>; 
        <font color='#009900'>/*!
            requires
                - lambda_ &gt;= 0
            ensures
                - #get_lambda() == lambda_ 
        !*/</font>

        <font color='#0000FF'><u>double</u></font> <b><a name='get_lambda'></a>get_lambda</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
        <font color='#009900'>/*!
            ensures
                - returns the regularization parameter.  It is the parameter that 
                  determines the trade off between trying to fit the training data 
                  exactly or allowing more errors but hopefully improving the 
                  generalization ability of the resulting function.  Smaller values 
                  encourage exact fitting while larger values of lambda may encourage 
                  better generalization. 
        !*/</font>

        <font color='#0000FF'><u>void</u></font> <b><a name='set_max_iterations'></a>set_max_iterations</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> max_iter
        <font face='Lucida Console'>)</font>; 
        <font color='#009900'>/*!
            ensures
                - #get_max_iterations() == max_iter
        !*/</font>

        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_max_iterations'></a>get_max_iterations</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font>; 
        <font color='#009900'>/*!
            ensures
                - returns the maximum number of iterations the LSPI algorithm is allowed
                  to run before it is required to stop and return a result.
        !*/</font>

        <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
            <font color='#0000FF'>typename</font> vector_type
            <font color='#5555FF'>&gt;</font>
        policy<font color='#5555FF'>&lt;</font>feature_extractor<font color='#5555FF'>&gt;</font> <b><a name='train'></a>train</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> vector_type<font color='#5555FF'>&amp;</font> samples
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>;
        <font color='#009900'>/*!
            requires
                - samples.size() &gt; 0
                - samples is something with an interface that looks like 
                  std::vector&lt;process_sample&lt;feature_extractor&gt;&gt;.  That is, it should
                  be some kind of array of process_sample objects.
            ensures
                - Trains a policy based on the given data and returns the results.  The
                  idea is to find a policy that will obtain the largest possible reward
                  when run on the process that generated the samples.  In particular, 
                  if the returned policy is P then:
                    - P(S) == the best action to take when in state S.
                    - if (feature_extractor::force_last_weight_to_1) then
                        - The last element of P.get_weights() is 1. 
        !*/</font>

    <b>}</b>;

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<b>}</b>

<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_LSPI_ABSTRACT_Hh_
</font>


</pre></body></html>