{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torchvision\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "from torchvision.utils import save_image\n", "\n", "import numpy as np\n", "import datetime\n", "\n", "from matplotlib.pyplot import imshow, imsave\n", "# %matplotlib inline\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_sample_image(generator, noise_dim):\n", " z = torch.randn(100, noise_dim).to(device)\n", " generated_images = generator(z).view(100, 28, 28)\n", " result = generated_images.cpu().data.numpy()\n", " img = np.zeros([280, 280])\n", " for j in range(10):\n", " img[j * 28:(j + 1) * 28] = np.concatenate([x for x in result[j * 10:(j + 1) * 10]], axis=-1)\n", " return img\n", "\n", "class Discriminator(nn.Module):\n", " def __init__(self, input_size=784, num_classes=1):\n", " super(Discriminator, self).__init__()\n", " self.layers = nn.Sequential(\n", " nn.Linear(input_size, 512),\n", " nn.LeakyReLU(0.2),\n", " nn.Linear(512, 256),\n", " nn.LeakyReLU(0.2),\n", " nn.Linear(256, num_classes),\n", " nn.Sigmoid(),\n", " )\n", "\n", " def forward(self, x):\n", " x = x.view(x.size(0), -1)\n", " x = self.layers(x)\n", " return x\n", "\n", "class Generator(nn.Module):\n", " def __init__(self, input_size=100, num_classes=784):\n", " super(Generator, self).__init__()\n", " self.layers = nn.Sequential(\n", " nn.Linear(input_size, 128),\n", " nn.LeakyReLU(0.2),\n", " nn.Linear(128, 256),\n", " nn.BatchNorm1d(256),\n", " nn.LeakyReLU(0.2),\n", " nn.Linear(256, 512),\n", " nn.BatchNorm1d(512),\n", " nn.LeakyReLU(0.2),\n", " nn.Linear(512, 1024),\n", " nn.BatchNorm1d(1024),\n", " nn.LeakyReLU(0.2),\n", " nn.Linear(1024, num_classes),\n", " nn.Tanh()\n", " )\n", "\n", " def forward(self, x):\n", " x = self.layers(x)\n", " x = x.view(x.size(0), 1, 28, 28)\n", " return x\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_noise = 100\n", "\n", "discriminator = Discriminator().to(device)\n", "generator = Generator().to(device)\n", "\n", "transform = transforms.Compose([transforms.ToTensor(),\n", " transforms.Normalize(mean=[0.5],\n", " std=[0.5])]\n", ")\n", "\n", "mnist = datasets.MNIST(root='../data/', train=True, transform=transform, download=True)\n", "\n", "batch_size = 64\n", "\n", "data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True, drop_last=True)\n", "\n", "loss_fn = nn.BCELoss()\n", "d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n", "g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n", "\n", "max_epoch = 50\n", "step = 0\n", "n_critic = 1\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "d_labels = torch.ones(batch_size, 1).to(device)\n", "d_fakes = torch.zeros(batch_size, 1).to(device)\n", "\n", "# Training loop\n", "for epoch in range(max_epoch):\n", " for idx, (images, _) in enumerate(data_loader):\n", " real_images = images.to(device)\n", " real_outputs = discriminator(real_images)\n", " d_real_loss = 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fixed targets: 1 for real images, 0 for generated images.\n", "d_labels = torch.ones(batch_size, 1).to(device)\n", "d_fakes = torch.zeros(batch_size, 1).to(device)\n", "\n", "# Training loop\n", "for epoch in range(max_epoch):\n", "    for idx, (images, _) in enumerate(data_loader):\n", "        # --- Discriminator update ---\n", "        real_images = images.to(device)\n", "        real_outputs = discriminator(real_images)\n", "        d_real_loss = loss_fn(real_outputs, d_labels)\n", "\n", "        fake_noise = torch.randn(batch_size, n_noise).to(device)\n", "        fake_images = generator(fake_noise)\n", "        fake_outputs = discriminator(fake_images.detach())  # detach: no gradient into the generator here\n", "        d_fake_loss = loss_fn(fake_outputs, d_fakes)\n", "\n", "        d_loss = d_real_loss + d_fake_loss\n", "\n", "        discriminator.zero_grad()\n", "        d_loss.backward()\n", "        d_optimizer.step()\n", "\n", "        # --- Generator update (every n_critic discriminator steps) ---\n", "        if step % n_critic == 0:\n", "            fake_outputs = discriminator(generator(fake_noise))\n", "            g_loss = loss_fn(fake_outputs, d_labels)  # non-saturating loss: label fakes as real\n", "\n", "            generator.zero_grad()\n", "            g_loss.backward()\n", "            g_optimizer.step()\n", "\n", "        # Periodically render a sample grid (saving to disk is left commented out).\n", "        if step % 1000 == 0:\n", "            generator.eval()\n", "            img = get_sample_image(generator, n_noise)\n", "            # imsave('samples/{}_step{}.jpg'.format('gans', str(step).zfill(3)), img, cmap='gray')\n", "            generator.train()\n", "        step += 1\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Preview a 10x10 grid of generated digits and save the trained weights.\n", "generator.eval()\n", "imshow(get_sample_image(generator, n_noise), cmap='gray')\n", "\n", "torch.save(discriminator.state_dict(), 'discriminator.pth')\n", "torch.save(generator.state_dict(), 'generator.pth')\n" ] }
], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 }