|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "dnn_semantic_segmentation_ex.h" |
|
|
|
#include <iostream> |
|
#include <dlib/data_io.h> |
|
#include <dlib/gui_widgets.h> |
|
|
|
using namespace std; |
|
using namespace dlib; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Looks up the Voc2012class entry whose `index` member equals index_label.
//
// This is a convenience overload: it builds an equality predicate and hands it
// to the predicate-taking find_voc2012_class overload declared elsewhere.
// NOTE(review): what happens when no class matches is decided by that overload,
// which is not visible here - confirm before relying on it.
const Voc2012class& find_voc2012_class(const uint16_t& index_label)
{
    const auto has_requested_index = [&index_label](const Voc2012class& candidate)
    {
        return index_label == candidate.index;
    };

    return find_voc2012_class(has_requested_index);
}
|
|
|
|
|
// Converts a class index label into the RGB colour stored in the matching
// Voc2012class entry (its `rgb_label` member).
inline rgb_pixel index_label_to_rgb_label(uint16_t index_label)
{
    const Voc2012class& matching_class = find_voc2012_class(index_label);
    return matching_class.rgb_label;
}
|
|
|
|
|
|
|
void index_label_image_to_rgb_label_image( |
|
const matrix<uint16_t>& index_label_image, |
|
matrix<rgb_pixel>& rgb_label_image |
|
) |
|
{ |
|
const long nr = index_label_image.nr(); |
|
const long nc = index_label_image.nc(); |
|
|
|
rgb_label_image.set_size(nr, nc); |
|
|
|
for (long r = 0; r < nr; ++r) |
|
{ |
|
for (long c = 0; c < nc; ++c) |
|
{ |
|
rgb_label_image(r, c) = index_label_to_rgb_label(index_label_image(r, c)); |
|
} |
|
} |
|
} |
|
|
|
|
|
std::string get_most_prominent_non_background_classlabel(const matrix<uint16_t>& index_label_image) |
|
{ |
|
const long nr = index_label_image.nr(); |
|
const long nc = index_label_image.nc(); |
|
|
|
std::vector<unsigned int> counters(class_count); |
|
|
|
for (long r = 0; r < nr; ++r) |
|
{ |
|
for (long c = 0; c < nc; ++c) |
|
{ |
|
const uint16_t label = index_label_image(r, c); |
|
++counters[label]; |
|
} |
|
} |
|
|
|
const auto max_element = std::max_element(counters.begin() + 1, counters.end()); |
|
const uint16_t most_prominent_index_label = max_element - counters.begin(); |
|
|
|
return find_voc2012_class(most_prominent_index_label).classlabel; |
|
} |
|
|
|
|
|
|
|
int main(int argc, char** argv) try |
|
{ |
|
if (argc != 2) |
|
{ |
|
cout << "You call this program like this: " << endl; |
|
cout << "./dnn_semantic_segmentation_train_ex /path/to/images" << endl; |
|
cout << endl; |
|
cout << "You will also need a trained '" << semantic_segmentation_net_filename << "' file." << endl; |
|
cout << "You can either train it yourself (see example program" << endl; |
|
cout << "dnn_semantic_segmentation_train_ex), or download a" << endl; |
|
cout << "copy from here: http://dlib.net/files/" << semantic_segmentation_net_filename << endl; |
|
return 1; |
|
} |
|
|
|
|
|
anet_type net; |
|
deserialize(semantic_segmentation_net_filename) >> net; |
|
|
|
|
|
image_window win; |
|
|
|
matrix<rgb_pixel> input_image; |
|
matrix<uint16_t> index_label_image; |
|
matrix<rgb_pixel> rgb_label_image; |
|
|
|
|
|
const std::vector<file> files = dlib::get_files_in_directory_tree(argv[1], |
|
dlib::match_endings(".jpeg .jpg .png")); |
|
|
|
cout << "Found " << files.size() << " images, processing..." << endl; |
|
|
|
for (const file& file : files) |
|
{ |
|
|
|
load_image(input_image, file.full_name()); |
|
|
|
|
|
|
|
|
|
const matrix<uint16_t> temp = net(input_image); |
|
|
|
|
|
const chip_details chip_details( |
|
centered_rect(temp.nc() / 2, temp.nr() / 2, input_image.nc(), input_image.nr()), |
|
chip_dims(input_image.nr(), input_image.nc()) |
|
); |
|
extract_image_chip(temp, chip_details, index_label_image, interpolate_nearest_neighbor()); |
|
|
|
|
|
index_label_image_to_rgb_label_image(index_label_image, rgb_label_image); |
|
|
|
|
|
win.set_image(join_rows(input_image, rgb_label_image)); |
|
|
|
|
|
const std::string classlabel = get_most_prominent_non_background_classlabel(index_label_image); |
|
|
|
cout << file.name() << " : " << classlabel << " - hit enter to process the next image"; |
|
cin.get(); |
|
} |
|
} |
|
catch(std::exception& e) |
|
{ |
|
cout << e.what() << endl; |
|
} |
|
|
|
|