AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
raw
history blame
4.09 kB
// Copyright (C) 2016 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_VALIDATION_H_
#define DLIB_DNn_VALIDATION_H_
#include "../svm/cross_validate_object_detection_trainer_abstract.h"
#include "../svm/cross_validate_object_detection_trainer.h"
#include "layers.h"
#include <set>
namespace dlib
{
namespace impl
{
inline std::set<std::string> get_labels (
const std::vector<mmod_rect>& rects1,
const std::vector<mmod_rect>& rects2
)
{
std::set<std::string> labels;
for (auto& rr : rects1)
labels.insert(rr.label);
for (auto& rr : rects2)
labels.insert(rr.label);
return labels;
}
}
// Runs an MMOD detector over every image and scores its detections against
// truth_dets.  Returns a 1x3 row vector holding, in order:
//   - precision: correct hits / number of entries recorded in all_dets
//                (1 when no entries were recorded),
//   - recall:    correct hits / number of non-ignored truth boxes
//                (1 when there are no such boxes),
//   - average_precision(all_dets, missing_detections).
// Matching is done per label: a detection is only compared against truth
// boxes with the same label.  Truth boxes flagged ignore are handed to
// impl::number_of_truth_hits() for every label (presumably so detections
// overlapping them, per overlaps_ignore_tester, are not counted -- confirm
// against that helper).  adjust_threshold is forwarded unchanged to
// detector.loss_details().to_label().
// Requires: is_learning_problem(images, truth_dets) == true.
template <
    typename SUBNET,
    typename image_array_type
    >
const matrix<double,1,3> test_object_detection_function (
    loss_mmod<SUBNET>& detector,
    const image_array_type& images,
    const std::vector<std::vector<mmod_rect>>& truth_dets,
    const test_box_overlap& overlap_tester = test_box_overlap(),
    const double adjust_threshold = 0,
    const test_box_overlap& overlaps_ignore_tester = test_box_overlap()
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT( is_learning_problem(images,truth_dets) == true ,
                "\t matrix test_object_detection_function()"
                << "\n\t invalid inputs were given to this function"
                << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
                << "\n\t images.size(): " << images.size()
                );

    double correct_hits = 0;
    double total_true_targets = 0;
    // (confidence, correct?) pairs accumulated across all images by
    // impl::number_of_truth_hits().
    std::vector<std::pair<double,bool> > all_dets;
    unsigned long missing_detections = 0;
    resizable_tensor temp;

    for (unsigned long i = 0; i < images.size(); ++i)
    {
        // Run the network on this single image and decode its detections.
        std::vector<mmod_rect> hits;
        detector.to_tensor(&images[i], &images[i]+1, temp);
        detector.subnet().forward(temp);
        detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold);

        // The ignore list does not depend on the label being scored, so build
        // it once per image instead of once per label.
        std::vector<rectangle> ignore;
        for (auto&& b : truth_dets[i])
        {
            if (b.ignore)
                ignore.push_back(b.rect);
        }

        // Score each label independently.
        for (const auto& label : impl::get_labels(truth_dets[i], hits))
        {
            std::vector<full_object_detection> truth_boxes;
            std::vector<std::pair<double,rectangle>> boxes;

            for (auto&& b : truth_dets[i])
            {
                if (!b.ignore && b.label == label)
                {
                    truth_boxes.push_back(full_object_detection(b.rect));
                    ++total_true_targets;
                }
            }
            for (auto&& b : hits)
            {
                if (b.label == label)
                    boxes.push_back(std::make_pair(b.detection_confidence, b.rect));
            }

            correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlaps_ignore_tester);
        }
    }

    // Rank detections by descending confidence (presumably the ordering
    // average_precision() expects -- confirm against its documentation).
    std::sort(all_dets.rbegin(), all_dets.rend());

    double precision, recall;
    const double total_hits = all_dets.size();
    // Define the degenerate cases as perfect scores rather than dividing by zero.
    if (total_hits == 0)
        precision = 1;
    else
        precision = correct_hits / total_hits;

    if (total_true_targets == 0)
        recall = 1;
    else
        recall = correct_hits / total_true_targets;

    matrix<double, 1, 3> res;
    res = precision, recall, average_precision(all_dets, missing_detections);
    return res;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_VALIDATION_H_