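// This example shows how to train a Deep Convolutional Generative Adversarial Network
// (DCGAN) on the MNIST dataset with dlib's deep learning tools. Two networks are
// trained against each other: a generator that turns random noise into 28x28 digit
// images, and a discriminator that tries to tell the generated images apart from real
// MNIST digits. The training loop updates both networks by hand, without a
// dnn_trainer, because the generator learns from the gradient the discriminator
// computes with respect to its input.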
#include <algorithm>
#include <iostream>

#include <dlib/data_io.h>
#include <dlib/dnn.h>
#include <dlib/gui_widgets.h>
#include <dlib/matrix.h>

using namespace std;
using namespace dlib;
|
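// The generator's input is a vector of 100 Gaussian noise values, stored as an array
// of 1x1 matrices so it can be fed directly to the network's input layer.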
const size_t noise_size = 100;
using noise_t = std::array<matrix<float, 1, 1>, noise_size>;

noise_t make_noise(dlib::rand& rnd)
{
    noise_t noise;
    for (auto& n : noise)
    {
        n = rnd.get_random_gaussian();
    }
    return noise;
}
|
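// A convolution and a transposed convolution with user-specified padding. These
// aliases just make the network definitions below easier to read.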
template<long num_filters, long kernel_size, int stride, int padding, typename SUBNET>
using conp = add_layer<con_<num_filters, kernel_size, kernel_size, stride, stride, padding, padding>, SUBNET>;

template<long num_filters, long kernel_size, int stride, int padding, typename SUBNET>
using contp = add_layer<cont_<num_filters, kernel_size, kernel_size, stride, stride, padding, padding>, SUBNET>;
|
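// The generator turns a 1x1x100 noise tensor into a 28x28 image by stacking transposed
// convolutions with batch normalization and ReLU activations, ending in a sigmoid.
// The loss layer on top is only there so the generator has an operator() for producing
// images; it plays no role in the adversarial training done in main().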
using generator_type =
    loss_binary_log_per_pixel<
    sig<contp<1, 4, 2, 1,
    relu<bn_con<contp<64, 4, 2, 1,
    relu<bn_con<contp<128, 3, 2, 1,
    relu<bn_con<contp<256, 4, 1, 0,
    input<noise_t>
    >>>>>>>>>>>>;
|
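// The discriminator is a plain convolutional classifier: it takes a 28x28 grayscale
// image and outputs a single value indicating whether the image looks real or fake.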
using discriminator_type =
    loss_binary_log<
    conp<1, 3, 1, 0,
    leaky_relu<bn_con<conp<256, 4, 2, 1,
    leaky_relu<bn_con<conp<128, 4, 2, 1,
    leaky_relu<conp<64, 4, 2, 1,
    input<matrix<unsigned char>>
    >>>>>>>>>>;
|
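// Helpers to run the generator and convert its float outputs, which the sigmoid keeps
// in [0, 1], back into 8-bit images.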
matrix<unsigned char> generate_image(generator_type& net, const noise_t& noise)
{
    const matrix<float> output = net(noise);
    matrix<unsigned char> image;
    assign_image(image, 255 * output);
    return image;
}

std::vector<matrix<unsigned char>> get_generated_images(const tensor& out)
{
    std::vector<matrix<unsigned char>> images;
    for (long n = 0; n < out.num_samples(); ++n)
    {
        matrix<float> output = image_plane(out, n);
        matrix<unsigned char> image;
        assign_image(image, 255 * output);
        images.push_back(std::move(image));
    }
    return images;
}
|
int main(int argc, char** argv) try
{
    if (argc != 2)
    {
        cout << "This example needs the MNIST dataset to run!" << endl;
        cout << "You can get MNIST from http://yann.lecun.com/exdb/mnist/" << endl;
        cout << "Download the 4 files that comprise the dataset, decompress them, and" << endl;
        cout << "put them in a folder. Then give that folder as input to this program." << endl;
        return EXIT_FAILURE;
    }
|
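    // Load the MNIST dataset. Only the training images are used below; the labels and
    // the test split are loaded simply because load_mnist_dataset fills all of them.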
    std::vector<matrix<unsigned char>> training_images;
    std::vector<unsigned long> training_labels;
    std::vector<matrix<unsigned char>> testing_images;
    std::vector<unsigned long> testing_labels;
    load_mnist_dataset(argv[1], training_images, training_labels, testing_images, testing_labels);
|
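    // Fix the random seeds used for network initialization and for the noise generator.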
    srand(1234);
    dlib::rand rnd(std::rand());
|
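    // Instantiate the generator and the discriminator.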
    generator_type generator;
    discriminator_type discriminator;

    // Set the slope of all leaky_relu activations in the discriminator to 0.2.
    visit_computational_layers(discriminator, [](leaky_relu_& l){ l = leaky_relu_(0.2); });

    // Remove the biases from layers that feed into a batch normalization layer, since
    // the normalization already provides a bias term.
    disable_duplicative_biases(generator);
    disable_duplicative_biases(discriminator);

    // Forward some noise through both networks once so their internal tensors are
    // allocated, then print the architectures and parameter counts.
    discriminator(generate_image(generator, make_noise(rnd)));
    cout << "generator (" << count_parameters(generator) << " parameters)" << endl;
    cout << generator << endl;
    cout << "discriminator (" << count_parameters(discriminator) << " parameters)" << endl;
    cout << discriminator << endl;
|
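    // One adam solver per computational layer for each network. Training is driven by
    // hand in the loop below instead of with a dnn_trainer, because the generator is
    // updated from the gradient the discriminator produces with respect to its input.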
    std::vector<adam> g_solvers(generator.num_computational_layers, adam(0, 0.5, 0.999));
    std::vector<adam> d_solvers(discriminator.num_computational_layers, adam(0, 0.5, 0.999));
    double learning_rate = 2e-4;
|
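    // Resume training from the last sync file, if one exists.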
    size_t iteration = 0;
    if (file_exists("dcgan_sync"))
    {
        deserialize("dcgan_sync") >> generator >> discriminator >> iteration;
    }

    // loss_binary_log expects the label +1 for the positive class (real images here)
    // and -1 for the negative class (fake images).
    const size_t minibatch_size = 64;
    const std::vector<float> real_labels(minibatch_size, 1);
    const std::vector<float> fake_labels(minibatch_size, -1);
    dlib::image_window win;
    resizable_tensor real_samples_tensor, fake_samples_tensor, noises_tensor;
    running_stats<double> g_loss, d_loss;
    while (iteration < 50000)
    {
|
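        // Train the discriminator on a minibatch of real MNIST images.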
        std::vector<matrix<unsigned char>> real_samples;
        while (real_samples.size() < minibatch_size)
        {
            auto idx = rnd.get_random_32bit_number() % training_images.size();
            real_samples.push_back(training_images[idx]);
        }

        // The next four calls do the work of a dnn_trainer's train_one_step(): compute
        // the loss, backpropagate it, and update the discriminator's parameters.
        discriminator.to_tensor(real_samples.begin(), real_samples.end(), real_samples_tensor);
        d_loss.add(discriminator.compute_loss(real_samples_tensor, real_labels.begin()));
        discriminator.back_propagate_error(real_samples_tensor);
        discriminator.update_parameters(d_solvers, learning_rate);
|
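        // Train the discriminator on a minibatch of generated (fake) images: sample
        // noise, run it through the generator, and label the results as fake.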
        std::vector<noise_t> noises;
        while (noises.size() < minibatch_size)
        {
            noises.push_back(make_noise(rnd));
        }

        generator.to_tensor(noises.begin(), noises.end(), noises_tensor);
        const auto fake_samples = get_generated_images(generator.forward(noises_tensor));

        discriminator.to_tensor(fake_samples.begin(), fake_samples.end(), fake_samples_tensor);
        d_loss.add(discriminator.compute_loss(fake_samples_tensor, fake_labels.begin()));
        discriminator.back_propagate_error(fake_samples_tensor);
        discriminator.update_parameters(d_solvers, learning_rate);
|
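        // Train the generator. This is the adversarial step: the fake samples are
        // evaluated by the discriminator once more, but against the real labels, so
        // the loss measures how far the generator is from fooling the discriminator.
        // The error is backpropagated through the discriminator without updating its
        // parameters, and the gradient with respect to its input is then used to
        // update the generator.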
        g_loss.add(discriminator.compute_loss(fake_samples_tensor, real_labels.begin()));
        discriminator.back_propagate_error(fake_samples_tensor);
        const tensor& d_grad = discriminator.get_final_data_gradient();
        generator.back_propagate_error(noises_tensor, d_grad);
        generator.update_parameters(g_solvers, learning_rate);
|
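        // Every 1000 iterations: save a sync file, report the running losses, and show
        // a tiling of the current fake samples. d_loss accumulates two values per
        // iteration (real and fake batches), so its mean is doubled to report the
        // combined discriminator loss.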
        if (++iteration % 1000 == 0)
        {
            serialize("dcgan_sync") << generator << discriminator << iteration;
            std::cout <<
                "step#: " << iteration <<
                "\tdiscriminator loss: " << d_loss.mean() * 2 <<
                "\tgenerator loss: " << g_loss.mean() << '\n';
            win.set_image(tile_images(fake_samples));
            win.set_title("DCGAN step#: " + to_string(iteration));
            d_loss.clear();
            g_loss.clear();
        }
    }
|
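    // Once training is done the discriminator is no longer needed, so only the
    // generator is cleaned and saved.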
    generator.clean();
    serialize("dcgan_mnist.dnn") << generator;
|
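    // Finally, generate images interactively: forward fresh noise through the
    // generator, display the result, and report whether the discriminator is fooled.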
    while (!win.is_closed())
    {
        const auto image = generate_image(generator, make_noise(rnd));
        const auto real = discriminator(image) > 0;
        win.set_image(image);
        cout << "The discriminator thinks it's " << (real ? "real" : "fake");
        cout << ". Hit enter to generate a new image";
        cin.get();
    }
|
    return EXIT_SUCCESS;
}
catch(exception& e)
{
    cout << e.what() << endl;
    return EXIT_FAILURE;
}
|
|