// File size: 3,608 Bytes
// Revision: 9375c9a
// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
#define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
#include "approximate_linear_models_abstract.h"
#include "../matrix.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
struct process_sample
{
typedef feature_extractor feature_extractor_type;
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;
process_sample(){}
process_sample(
const state_type& s,
const action_type& a,
const state_type& n,
const double& r
) : state(s), action(a), next_state(n), reward(r) {}
state_type state;
action_type action;
state_type next_state;
double reward;
};
template < typename feature_extractor >
void serialize (const process_sample<feature_extractor>& item, std::ostream& out)
{
serialize(item.state, out);
serialize(item.action, out);
serialize(item.next_state, out);
serialize(item.reward, out);
}
template < typename feature_extractor >
void deserialize (process_sample<feature_extractor>& item, std::istream& in)
{
deserialize(item.state, in);
deserialize(item.action, in);
deserialize(item.next_state, in);
deserialize(item.reward, in);
}
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
class policy
{
public:
typedef feature_extractor feature_extractor_type;
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;
policy (
)
{
w.set_size(fe.num_features());
w = 0;
}
policy (
const matrix<double,0,1>& weights_,
const feature_extractor& fe_
) : w(weights_), fe(fe_) {}
action_type operator() (
const state_type& state
) const
{
return fe.find_best_action(state,w);
}
const feature_extractor& get_feature_extractor (
) const { return fe; }
const matrix<double,0,1>& get_weights (
) const { return w; }
private:
matrix<double,0,1> w;
feature_extractor fe;
};
template < typename feature_extractor >
inline void serialize(const policy<feature_extractor>& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.get_feature_extractor(), out);
serialize(item.get_weights(), out);
}
template < typename feature_extractor >
inline void deserialize(policy<feature_extractor>& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::policy object.");
feature_extractor fe;
matrix<double,0,1> w;
deserialize(fe, in);
deserialize(w, in);
item = policy<feature_extractor>(w,fe);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_