#include <chrono>
#include <ctime>
#include <iostream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <dlib/dnn.h>
#include <dlib/image_io.h>
#include <dlib/misc_api.h>

using namespace dlib;
using namespace std;
std::vector<std::vector<string>> load_objects_list ( |
const string& dir |
) |
{ |
std::vector<std::vector<string>> objects; |
for (auto subdir : directory(dir).get_dirs()) |
{ |
std::vector<string> imgs; |
for (auto img : subdir.get_files()) |
imgs.push_back(img); |
if (imgs.size() != 0) |
objects.push_back(imgs); |
} |
return objects; |
} |
void load_mini_batch ( |
const size_t num_people, |
const size_t samples_per_id, |
dlib::rand& rnd, |
const std::vector<std::vector<string>>& objs, |
std::vector<matrix<rgb_pixel>>& images, |
std::vector<unsigned long>& labels |
) |
{ |
images.clear(); |
labels.clear(); |
DLIB_CASSERT(num_people <= objs.size(), "The dataset doesn't have that many people in it."); |
std::vector<bool> already_selected(objs.size(), false); |
matrix<rgb_pixel> image; |
for (size_t i = 0; i < num_people; ++i) |
{ |
size_t id = rnd.get_random_32bit_number()%objs.size(); |
while(already_selected[id]) |
id = rnd.get_random_32bit_number()%objs.size(); |
already_selected[id] = true; |
for (size_t j = 0; j < samples_per_id; ++j) |
{ |
const auto& obj = objs[id][rnd.get_random_32bit_number()%objs[id].size()]; |
load_image(image, obj); |
images.push_back(std::move(image)); |
labels.push_back(id); |
} |
} |
for (auto&& crop : images) |
{ |
disturb_colors(crop,rnd); |
if (rnd.get_random_double() > 0.1) |
crop = jitter_image(crop,rnd); |
} |
DLIB_CASSERT(images.size() > 0); |
for (auto&& img : images) |
{ |
DLIB_CASSERT(img.nr() == images[0].nr() && img.nc() == images[0].nc(), |
"All the images in a single mini-batch must be the same size."); |
} |
} |
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET> |
using residual = add_prev1<block<N,BN,1,tag1<SUBNET>>>; |
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET> |
using residual_down = add_prev2<avg_pool<2,2,2,2,skip1<tag2<block<N,BN,2,tag1<SUBNET>>>>>>; |
template <int N, template <typename> class BN, int stride, typename SUBNET> |
using block = BN<con<N,3,3,1,1,relu<BN<con<N,3,3,stride,stride,SUBNET>>>>>; |
// Convenience aliases: a residual unit followed by relu, with N filters.
//   res  / res_down   use bn_con as the normalization layer.
//   ares / ares_down  are the same units with affine in place of bn_con.
// The *_down variants are built on residual_down instead of residual.
template <int N, typename SUBNET>
using res = relu<residual<block, N, bn_con, SUBNET>>;

template <int N, typename SUBNET>
using ares = relu<residual<block, N, affine, SUBNET>>;

template <int N, typename SUBNET>
using res_down = relu<residual_down<block, N, bn_con, SUBNET>>;

template <int N, typename SUBNET>
using ares_down = relu<residual_down<block, N, affine, SUBNET>>;
template <typename SUBNET> using level0 = res_down<256,SUBNET>; |
template <typename SUBNET> using level1 = res<256,res<256,res_down<256,SUBNET>>>; |
template <typename SUBNET> using level2 = res<128,res<128,res_down<128,SUBNET>>>; |
template <typename SUBNET> using level3 = res<64,res<64,res<64,res_down<64,SUBNET>>>>; |
template <typename SUBNET> using level4 = res<32,res<32,res<32,SUBNET>>>; |
template <typename SUBNET> using alevel0 = ares_down<256,SUBNET>; |
template <typename SUBNET> using alevel1 = ares<256,ares<256,ares_down<256,SUBNET>>>; |
template <typename SUBNET> using alevel2 = ares<128,ares<128,ares_down<128,SUBNET>>>; |
template <typename SUBNET> using alevel3 = ares<64,ares<64,ares<64,ares_down<64,SUBNET>>>>; |
template <typename SUBNET> using alevel4 = ares<32,ares<32,ares<32,SUBNET>>>; |
// The network that gets trained.  Read bottom-up (innermost layer first):
//   RGB image input
//   -> 32-filter 7x7 convolution, stride 2 -> bn_con -> relu
//   -> 3x3 max pooling, stride 2
//   -> level4 -> level3 -> level2 -> level1 -> level0
//   -> avg_pool_everything
//   -> 128-output fully connected layer without bias
//   -> loss_metric
using net_type = loss_metric<fc_no_bias<128, avg_pool_everything<
                            level0<
                            level1<
                            level2<
                            level3<
                            level4<
                            max_pool<3, 3, 2, 2, relu<bn_con<con<32, 7, 7, 2, 2,
                            input_rgb_image
                            >>>>>>>>>>>>;
// Counterpart of net_type in which every bn_con layer is replaced by affine
// (alevelN stages, and affine after the first convolution).  main() converts
// the trained net_type into this type before running the evaluation.
// Read bottom-up (innermost layer first):
//   RGB image input
//   -> 32-filter 7x7 convolution, stride 2 -> affine -> relu
//   -> 3x3 max pooling, stride 2
//   -> alevel4 -> alevel3 -> alevel2 -> alevel1 -> alevel0
//   -> avg_pool_everything
//   -> 128-output fully connected layer without bias
//   -> loss_metric
using anet_type = loss_metric<fc_no_bias<128, avg_pool_everything<
                            alevel0<
                            alevel1<
                            alevel2<
                            alevel3<
                            alevel4<
                            max_pool<3, 3, 2, 2, relu<affine<con<32, 7, 7, 2, 2,
                            input_rgb_image
                            >>>>>>>>>>>>;
int main(int argc, char** argv) |
{ |
if (argc != 2) |
{ |
cout << "Give a folder as input. It should contain sub-folders of images and we will " << endl; |
cout << "learn to distinguish between these sub-folders with metric learning. " << endl; |
cout << "For example, you can run this program on the very small examples/johns dataset" << endl; |
cout << "that comes with dlib by running this command:" << endl; |
cout << " ./dnn_metric_learning_on_images_ex johns" << endl; |
return 1; |
} |
auto objs = load_objects_list(argv[1]); |
cout << "objs.size(): "<< objs.size() << endl; |
std::vector<matrix<rgb_pixel>> images; |
std::vector<unsigned long> labels; |
net_type net; |
dnn_trainer<net_type> trainer(net, sgd(0.0001, 0.9)); |
trainer.set_learning_rate(0.1); |
trainer.be_verbose(); |
trainer.set_synchronization_file("face_metric_sync", std::chrono::minutes(5)); |
trainer.set_iterations_without_progress_threshold(300); |
dlib::pipe<std::vector<matrix<rgb_pixel>>> qimages(4); |
dlib::pipe<std::vector<unsigned long>> qlabels(4); |
auto data_loader = [&qimages, &qlabels, &objs](time_t seed) |
{ |
dlib::rand rnd(time(0)+seed); |
std::vector<matrix<rgb_pixel>> images; |
std::vector<unsigned long> labels; |
while(qimages.is_enabled()) |
{ |
try |
{ |
load_mini_batch(5, 5, rnd, objs, images, labels); |
qimages.enqueue(images); |
qlabels.enqueue(labels); |
} |
catch(std::exception& e) |
{ |
cout << "EXCEPTION IN LOADING DATA" << endl; |
cout << e.what() << endl; |
} |
} |
}; |
std::thread data_loader1([data_loader](){ data_loader(1); }); |
std::thread data_loader2([data_loader](){ data_loader(2); }); |
std::thread data_loader3([data_loader](){ data_loader(3); }); |
std::thread data_loader4([data_loader](){ data_loader(4); }); |
std::thread data_loader5([data_loader](){ data_loader(5); }); |
while(trainer.get_learning_rate() >= 1e-4) |
{ |
qimages.dequeue(images); |
qlabels.dequeue(labels); |
trainer.train_one_step(images, labels); |
} |
trainer.get_net(); |
cout << "done training" << endl; |
net.clean(); |
serialize("metric_network_renset.dat") << net; |
qimages.disable(); |
qlabels.disable(); |
data_loader1.join(); |
data_loader2.join(); |
data_loader3.join(); |
data_loader4.join(); |
data_loader5.join(); |
dlib::rand rnd(time(0)); |
load_mini_batch(5, 5, rnd, objs, images, labels); |
anet_type testing_net = net; |
std::vector<matrix<float,0,1>> embedded = testing_net(images); |
int num_right = 0; |
int num_wrong = 0; |
for (size_t i = 0; i < embedded.size(); ++i) |
{ |
for (size_t j = i+1; j < embedded.size(); ++j) |
{ |
if (labels[i] == labels[j]) |
{ |
if (length(embedded[i]-embedded[j]) < testing_net.loss_details().get_distance_threshold()) |
++num_right; |
else |
++num_wrong; |
} |
else |
{ |
if (length(embedded[i]-embedded[j]) >= testing_net.loss_details().get_distance_threshold()) |
++num_right; |
else |
++num_wrong; |
} |
} |
} |
cout << "num_right: "<< num_right << endl; |
cout << "num_wrong: "<< num_wrong << endl; |
} |