// Copyright (C) 2016  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNn_VALIDATION_H_
#define DLIB_DNn_VALIDATION_H_

#include "../svm/cross_validate_object_detection_trainer_abstract.h"
#include "../svm/cross_validate_object_detection_trainer.h"
#include "layers.h"
#include <set>

namespace dlib
{
    namespace impl
    {
        // Returns every distinct label appearing in either rects1 or rects2.  No
        // filtering is applied: rectangles flagged as ignore contribute their labels
        // too.  std::set keeps the result unique and in sorted order.
        inline std::set<std::string> get_labels (
            const std::vector<mmod_rect>& rects1,
            const std::vector<mmod_rect>& rects2
        )
        {
            std::set<std::string> result;
            const auto collect = [&result](const std::vector<mmod_rect>& source)
            {
                for (const auto& r : source)
                    result.insert(r.label);
            };
            collect(rects1);
            collect(rects2);
            return result;
        }
    }

    /*!
        requires
            - is_learning_problem(images,truth_dets) == true
        ensures
            - Runs detector on each images[i] (one image per forward pass) and compares
              the resulting detections against truth_dets[i], one label at a time: a
              detection is only matched against truth boxes carrying the same label.
            - Truth boxes with ignore == true are never counted as targets.  They are
              handed to impl::number_of_truth_hits() as ignore regions for every label
              evaluated on that image.
            - adjust_threshold is forwarded to the loss layer's to_label() call.
            - overlap_tester and overlaps_ignore_tester are forwarded to
              impl::number_of_truth_hits().
            - returns a 1x3 matrix M where:
                - M(0) == precision: correct hits / number of collected detections
                  (1 if no detections were collected)
                - M(1) == recall: correct hits / number of non-ignored truth boxes
                  (1 if there are no such boxes)
                - M(2) == average_precision() of the collected detections
    !*/
    template <
        typename SUBNET,
        typename image_array_type
        >
    const matrix<double,1,3> test_object_detection_function (
        loss_mmod<SUBNET>& detector,
        const image_array_type& images,
        const std::vector<std::vector<mmod_rect>>& truth_dets,
        const test_box_overlap& overlap_tester = test_box_overlap(),
        const double adjust_threshold = 0,
        const test_box_overlap& overlaps_ignore_tester = test_box_overlap()
    )
    {
        // make sure requires clause is not broken
        DLIB_CASSERT( is_learning_problem(images,truth_dets) == true , 
                    "\t matrix test_object_detection_function()"
                    << "\n\t invalid inputs were given to this function"
                    << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
                    << "\n\t images.size(): " << images.size() 
                    );

        double correct_hits = 0;
        double total_true_targets = 0;

        std::vector<std::pair<double,bool> > all_dets;
        unsigned long missing_detections = 0;

        // Declared outside the loop so the same tensor object is reused for every image.
        resizable_tensor temp;

        for (unsigned long i = 0; i < images.size(); ++i)
        {
            std::vector<mmod_rect> hits; 
            detector.to_tensor(&images[i], &images[i]+1, temp);
            detector.subnet().forward(temp);
            detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold);

            // Ignore regions apply to every label, so gather them once per image
            // rather than rescanning the truth boxes for each label.
            std::vector<rectangle> ignore;
            for (auto&& b : truth_dets[i])
            {
                if (b.ignore)
                    ignore.push_back(b);
            }

            for (auto& label : impl::get_labels(truth_dets[i], hits))
            {
                // Non-ignored truth boxes and detections carrying the current label.
                std::vector<full_object_detection> truth_boxes;
                std::vector<std::pair<double,rectangle>> boxes;
                for (auto&& b : truth_dets[i])
                {
                    if (!b.ignore && b.label == label)
                    {
                        truth_boxes.push_back(full_object_detection(b.rect));
                        ++total_true_targets;
                    }
                }
                for (auto&& b : hits)
                {
                    if (b.label == label)
                        boxes.push_back(std::make_pair(b.detection_confidence, b.rect));
                }

                correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlaps_ignore_tester);
            }
        }

        // Sort in descending order (via reverse iterators).  NOTE(review): presumably
        // average_precision() expects the highest-scoring detections first — confirm.
        std::sort(all_dets.rbegin(), all_dets.rend());

        const double total_hits = all_dets.size();

        // An empty denominator leaves the ratio undefined; it is reported as 1.
        const double precision = (total_hits == 0) ? 1 : correct_hits / total_hits;
        const double recall = (total_true_targets == 0) ? 1 : correct_hits / total_true_targets;

        matrix<double, 1, 3> res;
        res = precision, recall, average_precision(all_dets, missing_detections);
        return res;
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_DNn_VALIDATION_H_