File size: 5,396 Bytes
9375c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright (C) 2015  Davis E. King ([email protected])
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_MNIST_CPp_
#define DLIB_MNIST_CPp_

#include "mnist.h"
#include <fstream>
#include "../byte_orderer.h"
#include "../uintn.h"

// ----------------------------------------------------------------------------------------

namespace dlib
{
    void load_mnist_dataset (
        const std::string& folder_name,
        std::vector<matrix<unsigned char> >& training_images,
        std::vector<unsigned long>& training_labels,
        std::vector<matrix<unsigned char> >& testing_images,
        std::vector<unsigned long>& testing_labels
    )
    {
        using namespace std;
        ifstream fin1((folder_name+"/train-images-idx3-ubyte").c_str(), ios::binary);
        if (!fin1)
        {
            fin1.open((folder_name + "/train-images.idx3-ubyte").c_str(), ios::binary);
        }

        ifstream fin2((folder_name+"/train-labels-idx1-ubyte").c_str(), ios::binary);
        if (!fin2)
        {
            fin2.open((folder_name + "/train-labels.idx1-ubyte").c_str(), ios::binary);
        }

        ifstream fin3((folder_name+"/t10k-images-idx3-ubyte").c_str(), ios::binary);
        if (!fin3)
        {
            fin3.open((folder_name + "/t10k-images.idx3-ubyte").c_str(), ios::binary);
        }

        ifstream fin4((folder_name+"/t10k-labels-idx1-ubyte").c_str(), ios::binary);
        if (!fin4)
        {
            fin4.open((folder_name + "/t10k-labels.idx1-ubyte").c_str(), ios::binary);
        }

        if (!fin1) throw error("Unable to open file train-images-idx3-ubyte or train-images.idx3-ubyte");
        if (!fin2) throw error("Unable to open file train-labels-idx1-ubyte or train-labels.idx1-ubyte");
        if (!fin3) throw error("Unable to open file t10k-images-idx3-ubyte or t10k-images.idx3-ubyte");
        if (!fin4) throw error("Unable to open file t10k-labels-idx1-ubyte or t10k-labels.idx1-ubyte");
		
        byte_orderer bo;

        // make sure the files have the contents we expect.
        uint32 magic, num, nr, nc, num2, num3, num4;
        fin1.read((char*)&magic, sizeof(magic));  bo.big_to_host(magic);
        fin1.read((char*)&num, sizeof(num));  bo.big_to_host(num);
        fin1.read((char*)&nr, sizeof(nr));  bo.big_to_host(nr);
        fin1.read((char*)&nc, sizeof(nc));  bo.big_to_host(nc);
        if (magic != 2051 || num != 60000 || nr != 28 || nc != 28)
            throw error("mnist dat files are corrupted.");

        fin2.read((char*)&magic, sizeof(magic));  bo.big_to_host(magic);
        fin2.read((char*)&num2, sizeof(num2));  bo.big_to_host(num2);
        if (magic != 2049 || num2 != 60000)
            throw error("mnist dat files are corrupted.");

        fin3.read((char*)&magic, sizeof(magic));  bo.big_to_host(magic);
        fin3.read((char*)&num3, sizeof(num3));  bo.big_to_host(num3);
        fin3.read((char*)&nr, sizeof(nr));  bo.big_to_host(nr);
        fin3.read((char*)&nc, sizeof(nc));  bo.big_to_host(nc);
        if (magic != 2051 || num3 != 10000 || nr != 28 || nc != 28)
            throw error("mnist dat files are corrupted.");

        fin4.read((char*)&magic, sizeof(magic));  bo.big_to_host(magic);
        fin4.read((char*)&num4, sizeof(num4));  bo.big_to_host(num4);
        if (magic != 2049 || num4 != 10000)
            throw error("mnist dat files are corrupted.");

        if (!fin1) throw error("Unable to read train-images-idx3-ubyte");
        if (!fin2) throw error("Unable to read train-labels-idx1-ubyte");
        if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte");
        if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte");


        training_images.resize(60000);
        training_labels.resize(60000);
        testing_images.resize(10000);
        testing_labels.resize(10000);

        for (size_t i = 0; i < training_images.size(); ++i)
        {
            training_images[i].set_size(nr,nc);
            fin1.read((char*)&training_images[i](0,0), nr*nc);
        }
        for (size_t i = 0; i < training_labels.size(); ++i)
        {
            char l;
            fin2.read(&l, 1);
            training_labels[i] = l;
        }

        for (size_t i = 0; i < testing_images.size(); ++i)
        {
            testing_images[i].set_size(nr,nc);
            fin3.read((char*)&testing_images[i](0,0), nr*nc);
        }
        for (size_t i = 0; i < testing_labels.size(); ++i)
        {
            char l;
            fin4.read(&l, 1);
            testing_labels[i] = l;
        }

        if (!fin1) throw error("Unable to read train-images-idx3-ubyte");
        if (!fin2) throw error("Unable to read train-labels-idx1-ubyte");
        if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte");
        if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte");

        if (fin1.get() != EOF) throw error("Unexpected bytes at end of train-images-idx3-ubyte");
        if (fin2.get() != EOF) throw error("Unexpected bytes at end of train-labels-idx1-ubyte");
        if (fin3.get() != EOF) throw error("Unexpected bytes at end of t10k-images-idx3-ubyte");
        if (fin4.get() != EOF) throw error("Unexpected bytes at end of t10k-labels-idx1-ubyte");
    }
}

// ----------------------------------------------------------------------------------------

#endif // DLIB_MNIST_CPp_