// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNn_SOLVERS_H_
#define DLIB_DNn_SOLVERS_H_

#include "solvers_abstract.h"
#include "../cuda/tensor.h"
#include <iostream>
#include "layers.h"

namespace dlib
{
    class sgd
    {
    public:

        explicit sgd(
            float weight_decay_,
            float momentum_ = 0.9
        )
        {
            weight_decay = weight_decay_;
            momentum = momentum_;
        }

        sgd(
        ) : sgd(0.0005, 0.9)
        {
        }

        float get_momentum (
        ) const { return momentum; }

        float get_weight_decay (
        ) const { return weight_decay; }

        template <typename layer_type>
        const tensor& operator() (
            const float learning_rate,
            const layer_type& l,
            const tensor& params_grad
        )
        {
            const tensor& params = l.get_layer_params();
            DLIB_CASSERT(params.size() != 0);
            if (v.size() == 0)
            {
                v.copy_size(params_grad);
                v = 0;
            }

            const double lr = learning_rate*get_learning_rate_multiplier(l);
            const double wd = weight_decay*get_weight_decay_multiplier(l);

            //perform: v = momentum*mat(v) - wd*lr*mat(params) - lr*mat(params_grad);
            tt::affine_transform(v, v, params, params_grad, momentum, -wd*lr, -lr);

            return v;
        }

        template <unsigned long N>
        const tensor& operator() (
            const float learning_rate,
            const fc_<N,FC_HAS_BIAS>& l,
            const tensor& params_grad
        )
        {
            update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.get_num_outputs());
            return v;
        }

        template < long _num_filters, long _nr, long _nc, int _stride_y, int _stride_x, int _padding_y, int _padding_x >
        const tensor& operator() (
            const float learning_rate,
            const con_<_num_filters,_nr,_nc,_stride_y,_stride_x,_padding_y,_padding_x>& l,
            const tensor& params_grad
        )
        {
            update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.num_filters());
            return v;
        }

        template < long _num_filters, long _nr, long _nc, int _stride_y, int _stride_x, int _padding_y, int _padding_x >
        const tensor& operator() (
            const float learning_rate,
            const cont_<_num_filters,_nr,_nc,_stride_y,_stride_x,_padding_y,_padding_x>& l,
            const tensor& params_grad
        )
        {
            update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.num_filters());
            return v;
        }

        template < layer_mode mode >
        const tensor& operator() (
            const float learning_rate,
            const bn_<mode>& l,
            const tensor& params_grad
        )
        {
            update_considering_bias(learning_rate, l, params_grad, params_grad.size()/2);
            return v;
        }

        friend void serialize(const sgd& item, std::ostream& out)
        {
            serialize("sgd2", out);
            serialize(item.v, out);
            serialize(item.weight_decay, out);
            serialize(item.momentum, out);
        }

        friend void deserialize(sgd& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "sgd2")
                throw serialization_error("Unexpected version found while deserializing dlib::sgd.");
            deserialize(item.v, in);
            deserialize(item.weight_decay, in);
            deserialize(item.momentum, in);
        }

        friend std::ostream& operator<< (std::ostream& out, const sgd& item)
        {
            out << "sgd: weight_decay="<<item.get_weight_decay() << ", momentum="<<item.get_momentum();
            return out;
        }

    private:

        template <typename layer_type>
        void update_considering_bias(
            const float learning_rate,
            const layer_type& l,
            const tensor& params_grad,
            unsigned long bias_offset
        )
        {
            const tensor& params = l.get_layer_params();
            DLIB_CASSERT(params.size() != 0);
            if (v.size() == 0)
            {
                v.copy_size(params_grad);
                v = 0;
            }

            double lr = learning_rate*get_learning_rate_multiplier(l);
            double wd = weight_decay*get_weight_decay_multiplier(l);

            //perform: v = momentum*mat(v) - wd*lr*mat(params) - lr*mat(params_grad);

            if (l.get_bias_learning_rate_multiplier() == 1 && l.get_bias_weight_decay_multiplier() == 1)
            {
                tt::affine_transform(v, v, params, params_grad, momentum, -wd*lr, -lr);
            }
            else
            {
                // Update the filter/weight part of the parameter tensor first...
                tt::affine_transform_range(0, bias_offset, v, v, params, params_grad, momentum, -wd*lr, -lr);

                // ...then update the biases, applying their own multipliers.
                lr *= l.get_bias_learning_rate_multiplier();
                wd *= l.get_bias_weight_decay_multiplier();
                tt::affine_transform_range(bias_offset, v.size(), v, v, params, params_grad, momentum, -wd*lr, -lr);
            }
        }

        tensor v;
        float weight_decay;
        float momentum;
    };
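
    // Usage sketch (illustrative comment only, not part of the original header):
    // an sgd object is not normally invoked directly.  Instead it is handed to a
    // dnn_trainer, which keeps one solver instance per layer and calls operator()
    // with each layer's parameter gradient after every mini-batch.  The network
    // type below is a made-up example.
    //
    //     using net_type = loss_multiclass_log<fc<10, input<matrix<float>>>>;
    //     net_type net;
    //     dnn_trainer<net_type, sgd> trainer(net, sgd(0.0005f, 0.9f));
    //     trainer.set_learning_rate(0.01);
    //     // trainer.train(training_samples, training_labels);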