// Copyright (C) 2020 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CIFAR_CPp_ #define DLIB_CIFAR_CPp_ #include "cifar.h" #include // ---------------------------------------------------------------------------------------- namespace dlib { namespace impl { void load_cifar_10_batch ( const std::string& folder_name, const std::string& batch_name, const size_t first_idx, const size_t images_per_batch, std::vector>& images, std::vector& labels ) { std::ifstream fin(folder_name + "/" + batch_name, std::ios::binary); if (!fin) throw error("Unable to open file " + batch_name); const long nr = 32; const long nc = 32; const long plane_size = nr * nc; const long image_size = 3 * plane_size; for (size_t i = 0; i < images_per_batch; ++i) { char l; fin.read(&l, 1); labels[first_idx + i] = l; images[first_idx + i].set_size(nr, nc); std::array buffer; fin.read((char*)(buffer.data()), buffer.size()); for (long k = 0; k < plane_size; ++k) { char r = buffer[0 * plane_size + k]; char g = buffer[1 * plane_size + k]; char b = buffer[2 * plane_size + k]; const long row = k / nr; const long col = k % nr; images[first_idx + i](row, col) = rgb_pixel(r, g, b); } } if (!fin) throw error("Unable to read file " + batch_name); if (fin.get() != EOF) throw error("Unexpected bytes at end of " + batch_name); } } void load_cifar_10_dataset ( const std::string& folder_name, std::vector>& training_images, std::vector& training_labels, std::vector>& testing_images, std::vector& testing_labels ) { using namespace std; const size_t images_per_batch = 10000; const size_t num_training_batches = 5; const size_t num_testing_batches = 1; training_images.resize(images_per_batch * num_training_batches); training_labels.resize(images_per_batch * num_training_batches); testing_images.resize(images_per_batch * num_testing_batches); testing_labels.resize(images_per_batch * num_testing_batches); std::vector training_batches_names{ "data_batch_1.bin", "data_batch_2.bin", "data_batch_3.bin", "data_batch_4.bin", "data_batch_5.bin", }; for (size_t i = 0; i < num_training_batches; ++i) { impl::load_cifar_10_batch( folder_name, training_batches_names[i], i * images_per_batch, images_per_batch, training_images, training_labels); } impl::load_cifar_10_batch( folder_name, "test_batch.bin", 0, images_per_batch, testing_images, testing_labels); } } // ---------------------------------------------------------------------------------------- #endif // DLIB_CIFAR_CPp_