// Copyright (C) 2016  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNn_UTILITIES_H_
#define DLIB_DNn_UTILITIES_H_

#include "core.h"
#include "utilities_abstract.h"
#include "../geometry.h"
#include <fstream>

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Overwrites every element of params with a random value drawn according to
    // formula (16) from "Understanding the difficulty of training deep feedforward
    // neural networks" by Xavier Glorot and Yoshua Bengio.
    //
    // params:                 tensor whose elements are all replaced.
    // num_inputs_and_outputs: divisor of the scale factor sqrt(6/num_inputs_and_outputs).
    // rnd:                    generator the random samples are taken from.
    inline void randomize_parameters (
        tensor& params,
        unsigned long num_inputs_and_outputs,
        dlib::rand& rnd
    )
    {
        // The scale factor does not depend on the element, so name it once up front.
        const double scale = std::sqrt(6.0/num_inputs_and_outputs);

        for (auto& element : params)
        {
            // Recenter the generator's sample around zero (presumably the sample lies
            // in [0,1), giving a value in [-1,1) -- confirm against dlib::rand), then
            // stretch it by the scale factor.
            element = 2*rnd.get_random_float()-1;
            element *= scale;
        }
    }

// ----------------------------------------------------------------------------------------

    namespace impl
    {
        class visitor_net_to_xml
        {
        public:

            visitor_net_to_xml(std::ostream& out_) : out(out_) {}

            template<typename input_layer_type>
            void operator()(size_t idx, const input_layer_type& l) 
            {
                out << "<layer idx='"<<idx<<"' type='input'>\n";
                to_xml(l,out);
                out << "</layer>\n";
            }

            template <typename T, typename U>
            void operator()(size_t idx, const add_loss_layer<T,U>& l) 
            {
                out << "<layer idx='"<<idx<<"' type='loss'>\n";
                to_xml(l.loss_details(),out);
                out << "</layer>\n";
            }

            template <typename T, typename U, typename E>
            void operator()(size_t idx, const add_layer<T,U,E>& l) 
            {
                out << "<layer idx='"<<idx<<"' type='comp'>\n";
                to_xml(l.layer_details(),out);
                out << "</layer>\n";
            }

            template <unsigned long ID, typename U, typename E>
            void operator()(size_t idx, const add_tag_layer<ID,U,E>& /*l*/) 
            {
                out << "<layer idx='"<<idx<<"' type='tag' id='"<<ID<<"'/>\n";
            }

            template <template<typename> class T, typename U>
            void operator()(size_t idx, const add_skip_layer<T,U>& /*l*/) 
            {
                out << "<layer idx='"<<idx<<"' type='skip' id='"<<(tag_id<T>::id)<<"'/>\n";
            }

        private:

            std::ostream& out;
        };
    }

    template <typename net_type>
    void net_to_xml (
        const net_type& net,
        std::ostream& out
    )
    {
        auto old_precision = out.precision(9);
        out << "<net>\n";
        visit_layers(net, impl::visitor_net_to_xml(out));
        out << "</net>\n";
        // restore the original stream precision.
        out.precision(old_precision);
    }

    // Writes an XML description of net to the file named filename.  The file is
    // opened with a default-constructed-mode std::ofstream and the actual
    // writing is done by calling net_to_xml() with that stream.
    //
    // net:      the network to describe.
    // filename: path of the file to write.
    //
    // Throws std::ios_base::failure if the file cannot be opened for writing.
    // Without this check a bad path would silently produce no output at all.
    template <typename net_type>
    void net_to_xml (
        const net_type& net,
        const std::string& filename
    )
    {
        std::ofstream fout(filename);
        if (!fout)
            throw std::ios_base::failure("Unable to open " + filename + " for writing.");
        net_to_xml(net, fout);
    }

// ----------------------------------------------------------------------------------------

    namespace impl
    {

        class visitor_net_map_input_to_output
        {
        public:

            visitor_net_map_input_to_output(dpoint& p_) : p(p_) {}

            dpoint& p;

            template<typename input_layer_type>
            void operator()(const input_layer_type& ) 
            {
            }

            template <typename T, typename U>
            void operator()(const add_loss_layer<T,U>& net) 
            {
                (*this)(net.subnet());
            }

            template <typename T, typename U, typename E>
            void operator()(const add_layer<T,U,E>& net) 
            {
                (*this)(net.subnet());
                p = net.layer_details().map_input_to_output(p);
            }
            template <bool B, typename T, typename U, typename E>
            void operator()(const dimpl::subnet_wrapper<add_layer<T,U,E>,B>& net) 
            {
                (*this)(net.subnet());
                p = net.layer_details().map_input_to_output(p);
            }


            template <unsigned long ID, typename U, typename E>
            void operator()(const add_tag_layer<ID,U,E>& net) 
            {
                // tag layers are an identity transform, so do nothing
                (*this)(net.subnet());
            }
            template <bool is_first, unsigned long ID, typename U, typename E>
            void operator()(const dimpl::subnet_wrapper<add_tag_layer<ID,U,E>,is_first>& net) 
            {
                // tag layers are an identity transform, so do nothing
                (*this)(net.subnet());
            }


            template <template<typename> class TAG_TYPE, typename U>
            void operator()(const add_skip_layer<TAG_TYPE,U>& net) 
            {
                (*this)(layer<TAG_TYPE>(net));
            }
            template <bool is_first, template<typename> class TAG_TYPE, typename SUBNET>
            void operator()(const dimpl::subnet_wrapper<add_skip_layer<TAG_TYPE,SUBNET>,is_first>& net) 
            {
                // skip layers are an identity transform, so do nothing
                (*this)(layer<TAG_TYPE>(net));
            }

        };

        class visitor_net_map_output_to_input
        {
        public:
            visitor_net_map_output_to_input(dpoint& p_) : p(p_) {}

            dpoint& p;

            template<typename input_layer_type>
            void operator()(const input_layer_type& ) 
            {
            }

            template <typename T, typename U>
            void operator()(const add_loss_layer<T,U>& net) 
            {
                (*this)(net.subnet());
            }

            template <typename T, typename U, typename E>
            void operator()(const add_layer<T,U,E>& net) 
            {
                p = net.layer_details().map_output_to_input(p);
                (*this)(net.subnet());
            }
            template <bool B, typename T, typename U, typename E>
            void operator()(const dimpl::subnet_wrapper<add_layer<T,U,E>,B>& net) 
            {
                p = net.layer_details().map_output_to_input(p);
                (*this)(net.subnet());
            }


            template <unsigned long ID, typename U, typename E>
            void operator()(const add_tag_layer<ID,U,E>& net) 
            {
                // tag layers are an identity transform, so do nothing
                (*this)(net.subnet());
            }
            template <bool is_first, unsigned long ID, typename U, typename E>
            void operator()(const dimpl::subnet_wrapper<add_tag_layer<ID,U,E>,is_first>& net) 
            {
                // tag layers are an identity transform, so do nothing
                (*this)(net.subnet());
            }


            template <template<typename> class TAG_TYPE, typename U>
            void operator()(const add_skip_layer<TAG_TYPE,U>& net) 
            {
                (*this)(layer<TAG_TYPE>(net));
            }
            template <bool is_first, template<typename> class TAG_TYPE, typename SUBNET>
            void operator()(const dimpl::subnet_wrapper<add_skip_layer<TAG_TYPE,SUBNET>,is_first>& net) 
            {
                // skip layers are an identity transform, so do nothing
                (*this)(layer<TAG_TYPE>(net));
            }

        };
    }

    // Maps the point p through net in the input-to-output direction using
    // impl::visitor_net_map_input_to_output and returns the mapped point.
    //
    // p is taken by value, so the caller's point is not modified.
    template <typename net_type>
    inline dpoint input_tensor_to_output_tensor(
        const net_type& net,
        dpoint p 
    )
    {
        // The visitor holds a reference to the local p and rewrites it in place
        // while it walks the network.
        impl::visitor_net_map_input_to_output point_mapper(p);
        point_mapper(net);
        return p;
    }

    // Maps the point p through net in the output-to-input direction using
    // impl::visitor_net_map_output_to_input and returns the mapped point.
    //
    // p is taken by value, so the caller's point is not modified.
    template <typename net_type>
    inline dpoint output_tensor_to_input_tensor(
        const net_type& net,
        dpoint p  
    )
    {
        // The visitor holds a reference to the local p and rewrites it in place
        // while it walks the network.
        impl::visitor_net_map_output_to_input point_mapper(p);
        point_mapper(net);
        return p;
    }

// ----------------------------------------------------------------------------------------

    // Returns the sum of size() over every parameter tensor that
    // visit_layer_parameters() reports for net.
    template <typename net_type>
    size_t count_parameters(
        const net_type& net
    )
    {
        size_t total = 0;

        // Called once per parameter tensor; accumulates its element count.
        auto add_tensor_size = [&total](const tensor& params) { total += params.size(); };

        visit_layer_parameters(net, add_tensor_size);
        return total;
    }

// ----------------------------------------------------------------------------------------

    namespace impl
    {
        // Visitor that assigns one fixed learning rate multiplier to every layer
        // it is invoked on, by calling set_learning_rate_multiplier() on the layer.
        class visitor_learning_rate_multiplier
        {
        public:
            // multiplier_ is the value handed to every visited layer.
            visitor_learning_rate_multiplier(double multiplier_) :
                multiplier(multiplier_) {}

            // Applies the stored multiplier to the given layer.
            template <typename layer_type>
            void operator()(layer_type& visited_layer) const
            {
                set_learning_rate_multiplier(visited_layer, multiplier);
            }

        private:

            double multiplier;
        };
    }

    template <typename net_type>
    void set_all_learning_rate_multipliers(
        net_type& net,
        double learning_rate_multiplier
    )
    {
        DLIB_CASSERT(learning_rate_multiplier >= 0);
        impl::visitor_learning_rate_multiplier temp(learning_rate_multiplier);
        visit_computational_layers(net, temp);
    }

    template <size_t begin, size_t end, typename net_type>
    void set_learning_rate_multipliers_range(
        net_type& net,
        double learning_rate_multiplier
    )
    {
        static_assert(begin <= end, "Invalid range");
        static_assert(end <= net_type::num_layers, "Invalid range");
        DLIB_CASSERT(learning_rate_multiplier >= 0);
        impl::visitor_learning_rate_multiplier temp(learning_rate_multiplier);
        visit_computational_layers_range<begin, end>(net, temp);
    }

// ----------------------------------------------------------------------------------------
}

#endif // DLIB_DNn_UTILITIES_H_