|
|
|
|
|
#ifndef DLIB_DNn_VALIDATION_H_ |
|
#define DLIB_DNn_VALIDATION_H_ |
|
|
|
#include "../svm/cross_validate_object_detection_trainer_abstract.h" |
|
#include "../svm/cross_validate_object_detection_trainer.h" |
|
#include "layers.h" |
|
#include <set> |
|
|
|
namespace dlib |
|
{ |
|
namespace impl |
|
{ |
|
inline std::set<std::string> get_labels ( |
|
const std::vector<mmod_rect>& rects1, |
|
const std::vector<mmod_rect>& rects2 |
|
) |
|
{ |
|
std::set<std::string> labels; |
|
for (auto& rr : rects1) |
|
labels.insert(rr.label); |
|
for (auto& rr : rects2) |
|
labels.insert(rr.label); |
|
return labels; |
|
} |
|
} |
|
|
|
template < |
|
typename SUBNET, |
|
typename image_array_type |
|
> |
|
const matrix<double,1,3> test_object_detection_function ( |
|
loss_mmod<SUBNET>& detector, |
|
const image_array_type& images, |
|
const std::vector<std::vector<mmod_rect>>& truth_dets, |
|
const test_box_overlap& overlap_tester = test_box_overlap(), |
|
const double adjust_threshold = 0, |
|
const test_box_overlap& overlaps_ignore_tester = test_box_overlap() |
|
) |
|
{ |
|
|
|
DLIB_CASSERT( is_learning_problem(images,truth_dets) == true , |
|
"\t matrix test_object_detection_function()" |
|
<< "\n\t invalid inputs were given to this function" |
|
<< "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets) |
|
<< "\n\t images.size(): " << images.size() |
|
); |
|
|
|
|
|
|
|
double correct_hits = 0; |
|
double total_true_targets = 0; |
|
|
|
std::vector<std::pair<double,bool> > all_dets; |
|
unsigned long missing_detections = 0; |
|
|
|
resizable_tensor temp; |
|
|
|
for (unsigned long i = 0; i < images.size(); ++i) |
|
{ |
|
std::vector<mmod_rect> hits; |
|
detector.to_tensor(&images[i], &images[i]+1, temp); |
|
detector.subnet().forward(temp); |
|
detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold); |
|
|
|
|
|
for (auto& label : impl::get_labels(truth_dets[i], hits)) |
|
{ |
|
std::vector<full_object_detection> truth_boxes; |
|
std::vector<rectangle> ignore; |
|
std::vector<std::pair<double,rectangle>> boxes; |
|
|
|
for (auto&& b : truth_dets[i]) |
|
{ |
|
if (b.ignore) |
|
{ |
|
ignore.push_back(b); |
|
} |
|
else if (b.label == label) |
|
{ |
|
truth_boxes.push_back(full_object_detection(b.rect)); |
|
++total_true_targets; |
|
} |
|
} |
|
for (auto&& b : hits) |
|
{ |
|
if (b.label == label) |
|
boxes.push_back(std::make_pair(b.detection_confidence, b.rect)); |
|
} |
|
|
|
correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlaps_ignore_tester); |
|
} |
|
} |
|
|
|
std::sort(all_dets.rbegin(), all_dets.rend()); |
|
|
|
double precision, recall; |
|
|
|
double total_hits = all_dets.size(); |
|
|
|
if (total_hits == 0) |
|
precision = 1; |
|
else |
|
precision = correct_hits / total_hits; |
|
|
|
if (total_true_targets == 0) |
|
recall = 1; |
|
else |
|
recall = correct_hits / total_true_targets; |
|
|
|
matrix<double, 1, 3> res; |
|
res = precision, recall, average_precision(all_dets, missing_detections); |
|
return res; |
|
} |
|
|
|
|
|
|
|
} |
|
|
|
#endif |
|
|
|
|