AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
raw
history blame
3.6 kB
// Copyright (C) 2020 Davis E. King ([email protected])
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CIFAR_CPp_
#define DLIB_CIFAR_CPp_
#include "cifar.h"
#include <array>
#include <fstream>
#include <string>
#include <vector>
// ----------------------------------------------------------------------------------------
namespace dlib
{
namespace impl
{
void load_cifar_10_batch (
const std::string& folder_name,
const std::string& batch_name,
const size_t first_idx,
const size_t images_per_batch,
std::vector<matrix<rgb_pixel>>& images,
std::vector<unsigned long>& labels
)
{
std::ifstream fin(folder_name + "/" + batch_name, std::ios::binary);
if (!fin) throw error("Unable to open file " + batch_name);
const long nr = 32;
const long nc = 32;
const long plane_size = nr * nc;
const long image_size = 3 * plane_size;
for (size_t i = 0; i < images_per_batch; ++i)
{
char l;
fin.read(&l, 1);
labels[first_idx + i] = l;
images[first_idx + i].set_size(nr, nc);
std::array<unsigned char, image_size> buffer;
fin.read((char*)(buffer.data()), buffer.size());
for (long k = 0; k < plane_size; ++k)
{
char r = buffer[0 * plane_size + k];
char g = buffer[1 * plane_size + k];
char b = buffer[2 * plane_size + k];
const long row = k / nr;
const long col = k % nr;
images[first_idx + i](row, col) = rgb_pixel(r, g, b);
}
}
if (!fin) throw error("Unable to read file " + batch_name);
if (fin.get() != EOF) throw error("Unexpected bytes at end of " + batch_name);
}
}
void load_cifar_10_dataset (
const std::string& folder_name,
std::vector<matrix<rgb_pixel>>& training_images,
std::vector<unsigned long>& training_labels,
std::vector<matrix<rgb_pixel>>& testing_images,
std::vector<unsigned long>& testing_labels
)
{
using namespace std;
const size_t images_per_batch = 10000;
const size_t num_training_batches = 5;
const size_t num_testing_batches = 1;
training_images.resize(images_per_batch * num_training_batches);
training_labels.resize(images_per_batch * num_training_batches);
testing_images.resize(images_per_batch * num_testing_batches);
testing_labels.resize(images_per_batch * num_testing_batches);
std::vector<string> training_batches_names{
"data_batch_1.bin",
"data_batch_2.bin",
"data_batch_3.bin",
"data_batch_4.bin",
"data_batch_5.bin",
};
for (size_t i = 0; i < num_training_batches; ++i)
{
impl::load_cifar_10_batch(
folder_name,
training_batches_names[i],
i * images_per_batch,
images_per_batch,
training_images,
training_labels);
}
impl::load_cifar_10_batch(
folder_name,
"test_batch.bin",
0,
images_per_batch,
testing_images,
testing_labels);
}
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_CIFAR_CPp_