// Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_LSPI_Hh_ #define DLIB_LSPI_Hh_ #include "lspi_abstract.h" #include "approximate_linear_models.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class lspi { public: typedef feature_extractor feature_extractor_type; typedef typename feature_extractor::state_type state_type; typedef typename feature_extractor::action_type action_type; explicit lspi( const feature_extractor& fe_ ) : fe(fe_) { init(); } lspi( ) { init(); } double get_discount ( ) const { return discount; } void set_discount ( double value ) { // make sure requires clause is not broken DLIB_ASSERT(0 < value && value <= 1, "\t void lspi::set_discount(value)" << "\n\t invalid inputs were given to this function" << "\n\t value: " << value ); discount = value; } const feature_extractor& get_feature_extractor ( ) const { return fe; } void be_verbose ( ) { verbose = true; } void be_quiet ( ) { verbose = false; } void set_epsilon ( double eps_ ) { // make sure requires clause is not broken DLIB_ASSERT(eps_ > 0, "\t void lspi::set_epsilon(eps_)" << "\n\t invalid inputs were given to this function" << "\n\t eps_: " << eps_ ); eps = eps_; } double get_epsilon ( ) const { return eps; } void set_lambda ( double lambda_ ) { // make sure requires clause is not broken DLIB_ASSERT(lambda_ >= 0, "\t void lspi::set_lambda(lambda_)" << "\n\t invalid inputs were given to this function" << "\n\t lambda_: " << lambda_ ); lambda = lambda_; } double get_lambda ( ) const { return lambda; } void set_max_iterations ( unsigned long max_iter ) { max_iterations = max_iter; } unsigned long get_max_iterations ( ) { return max_iterations; } template policy train ( const vector_type& samples ) const { // make sure requires clause is not broken DLIB_ASSERT(samples.size() > 0, "\t policy lspi::train(samples)" << "\n\t invalid inputs were given to this function" ); matrix w(fe.num_features()); w = 0; matrix prev_w, b, f1, f2; matrix A; double change; unsigned long iter = 0; do { A = identity_matrix(fe.num_features())*lambda; b = 0; for (unsigned long i = 0; i < samples.size(); ++i) { fe.get_features(samples[i].state, samples[i].action, f1); fe.get_features(samples[i].next_state, fe.find_best_action(samples[i].next_state,w), f2); A += f1*trans(f1 - discount*f2); b += f1*samples[i].reward; } prev_w = w; if (feature_extractor::force_last_weight_to_1) w = join_cols(pinv(colm(A,range(0,A.nc()-2)))*(b-colm(A,A.nc()-1)),mat(1.0)); else w = pinv(A)*b; change = length(w-prev_w); ++iter; if (verbose) std::cout << "iteration: " << iter << "\tchange: " << change << std::endl; } while(change > eps && iter < max_iterations); return policy(w,fe); } private: void init() { lambda = 0.01; discount = 0.8; eps = 0.01; verbose = false; max_iterations = 100; } double lambda; double discount; double eps; bool verbose; unsigned long max_iterations; feature_extractor fe; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_LSPI_Hh_