// Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ #define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ #include "approximate_linear_models_abstract.h" #include "../matrix.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename feature_extractor > struct process_sample { typedef feature_extractor feature_extractor_type; typedef typename feature_extractor::state_type state_type; typedef typename feature_extractor::action_type action_type; process_sample(){} process_sample( const state_type& s, const action_type& a, const state_type& n, const double& r ) : state(s), action(a), next_state(n), reward(r) {} state_type state; action_type action; state_type next_state; double reward; }; template < typename feature_extractor > void serialize (const process_sample& item, std::ostream& out) { serialize(item.state, out); serialize(item.action, out); serialize(item.next_state, out); serialize(item.reward, out); } template < typename feature_extractor > void deserialize (process_sample& item, std::istream& in) { deserialize(item.state, in); deserialize(item.action, in); deserialize(item.next_state, in); deserialize(item.reward, in); } // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class policy { public: typedef feature_extractor feature_extractor_type; typedef typename feature_extractor::state_type state_type; typedef typename feature_extractor::action_type action_type; policy ( ) { w.set_size(fe.num_features()); w = 0; } policy ( const matrix& weights_, const feature_extractor& fe_ ) : w(weights_), fe(fe_) {} action_type operator() ( const state_type& state ) const { return fe.find_best_action(state,w); } const feature_extractor& get_feature_extractor ( ) const { return fe; } const matrix& get_weights ( ) const { return w; } private: matrix w; feature_extractor fe; }; template < typename feature_extractor > inline void serialize(const policy& item, std::ostream& out) { int version = 1; serialize(version, out); serialize(item.get_feature_extractor(), out); serialize(item.get_weights(), out); } template < typename feature_extractor > inline void deserialize(policy& item, std::istream& in) { int version = 0; deserialize(version, in); if (version != 1) throw serialization_error("Unexpected version found while deserializing dlib::policy object."); feature_extractor fe; matrix w; deserialize(fe, in); deserialize(w, in); item = policy(w,fe); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_