File size: 3,608 Bytes
9375c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright (C) 2015  Davis E. King ([email protected])
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
#define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_

#include "approximate_linear_models_abstract.h"
#include "../matrix.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Plain aggregate holding the four values of one sample: a state, an action,
    // a next state and a scalar reward.  The state and action types are supplied
    // by the feature_extractor template argument through its state_type and
    // action_type typedefs.
    template <
        typename feature_extractor
        >
    struct process_sample
    {
        typedef feature_extractor feature_extractor_type;
        typedef typename feature_extractor::state_type state_type;
        typedef typename feature_extractor::action_type action_type;

        // Value-initializes every member.  Without the explicit initializers the
        // double reward (and state/action when they are built-in types) would be
        // left indeterminate, so copying or serializing a default constructed
        // sample would read uninitialized memory.  Class-type members still just
        // get their own default constructor.
        process_sample() : state(), action(), next_state(), reward(0) {}

        // Copies the four given values into the corresponding members:
        //   s -> state, a -> action, n -> next_state, r -> reward.
        process_sample(
            const state_type& s,
            const action_type& a,
            const state_type& n,
            const double& r
        ) : state(s), action(a), next_state(n), reward(r) {}

        state_type  state;
        action_type action;
        state_type  next_state;
        double reward;
    };

    // Writes a process_sample to out one field at a time, in the order
    // state, action, next_state, reward.  Each field is passed to whichever
    // serialize() overload matches its type; no version tag is written.
    template <typename feature_extractor>
    void serialize (
        const process_sample<feature_extractor>& item,
        std::ostream& out
    )
    {
        serialize(item.state,      out);
        serialize(item.action,     out);
        serialize(item.next_state, out);
        serialize(item.reward,     out);
    }

    // Reads a process_sample from in, filling the fields of item in place in
    // the order state, action, next_state, reward.  Each field is passed to
    // whichever deserialize() overload matches its type; no version tag is read.
    template <typename feature_extractor>
    void deserialize (
        process_sample<feature_extractor>& item,
        std::istream& in
    )
    {
        deserialize(item.state,      in);
        deserialize(item.action,     in);
        deserialize(item.next_state, in);
        deserialize(item.reward,     in);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename feature_extractor
        >
    class policy
    {
    public:

        typedef feature_extractor feature_extractor_type;
        typedef typename feature_extractor::state_type state_type;
        typedef typename feature_extractor::action_type action_type;


        policy (
        )
        {
            w.set_size(fe.num_features());
            w = 0;
        }

        policy (
            const matrix<double,0,1>& weights_,
            const feature_extractor& fe_
        ) : w(weights_), fe(fe_) {}

        action_type operator() (
            const state_type& state
        ) const
        {
            return fe.find_best_action(state,w);
        }

        const feature_extractor& get_feature_extractor (
        ) const { return fe; }

        const matrix<double,0,1>& get_weights (
        ) const { return w; }


    private:
        matrix<double,0,1> w;
        feature_extractor fe;
    };

    // Writes a policy to out as three consecutive items: the integer format
    // version (1), the feature extractor, then the weight vector.
    template <typename feature_extractor>
    inline void serialize (
        const policy<feature_extractor>& item,
        std::ostream& out
    )
    {
        int format_version = 1;
        serialize(format_version, out);

        serialize(item.get_feature_extractor(), out);
        serialize(item.get_weights(), out);
    }
    // Reads a policy from in.  Expects the integer format version first and
    // throws serialization_error if it is not 1.  The feature extractor and the
    // weights are then read into temporaries, and item is only assigned once
    // both reads have returned.
    template <typename feature_extractor>
    inline void deserialize (
        policy<feature_extractor>& item,
        std::istream& in
    )
    {
        int found_version = 0;
        deserialize(found_version, in);
        if (found_version != 1)
        {
            throw serialization_error("Unexpected version found while deserializing dlib::policy object.");
        }

        feature_extractor loaded_fe;
        matrix<double,0,1> loaded_weights;
        deserialize(loaded_fe, in);
        deserialize(loaded_weights, in);

        item = policy<feature_extractor>(loaded_weights, loaded_fe);
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_