In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

import numpy as np
import datetime

from matplotlib.pyplot import imshow, imsave
# %matplotlib inline

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def get_sample_image(generator, noise_dim, n_row=10, img_size=28, device=None):
    """Sample ``n_row * n_row`` images from ``generator`` and tile them into one grid.

    Args:
        generator: ``torch.nn.Module`` mapping a ``(N, noise_dim)`` noise batch to
            ``N`` outputs of ``img_size * img_size`` values each (any shape that
            reshapes to ``(N, img_size, img_size)``).
        noise_dim: Length of each noise vector fed to ``generator``.
        n_row: Number of rows and columns in the grid (default 10 -> 100 samples).
        img_size: Side length in pixels of each square sample (default 28).
        device: Device on which to place the noise. Defaults to the device of the
            generator's first parameter (CPU if it has none).

    Returns:
        A float64 NumPy array of shape ``(n_row * img_size, n_row * img_size)``.
        Tile ``(j, k)`` (row ``j``, column ``k``) holds sample ``j * n_row + k``.

    The caller is responsible for putting ``generator`` into ``eval()`` mode
    beforehand if that is desired.
    """
    if device is None:
        # Put the noise wherever the generator's weights live.
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    n_samples = n_row * n_row
    # Pure inference: no autograd graph is needed for the samples.
    with torch.no_grad():
        z = torch.randn(n_samples, noise_dim).to(device)
        samples = generator(z).reshape(n_samples, img_size, img_size)
    samples = samples.cpu().numpy().astype(np.float64)

    # Row-major tiling without a Python loop:
    # (n_row, n_row, H, W) -> (n_row, H, n_row, W) -> (n_row * H, n_row * W)
    grid = samples.reshape(n_row, n_row, img_size, img_size).transpose(0, 2, 1, 3)
    return grid.reshape(n_row * img_size, n_row * img_size)

class Discriminator(nn.Module):
    """MLP that maps each (flattened) input to a score squashed by a Sigmoid.

    Args:
        input_size: Number of values per sample once flattened (default 784).
        num_classes: Width of the output layer (default 1).
    """

    def __init__(self, input_size=784, num_classes=1):
        super().__init__()
        # Linear/LeakyReLU pairs for the hidden widths, then a Sigmoid head.
        # Modules are appended in the same order as a hand-written Sequential,
        # so state_dict keys ("layers.0.weight", ...) are unchanged.
        widths = [input_size, 512, 256]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)])
        modules.extend([nn.Linear(widths[-1], num_classes), nn.Sigmoid()])
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Flatten all non-batch dimensions and score each sample."""
        batch = x.size(0)
        return self.layers(x.view(batch, -1))

class Generator(nn.Module):
    """MLP generator mapping noise vectors to images with values in [-1, 1] (Tanh).

    Args:
        input_size: Length of each input noise vector (default 100).
        num_classes: Number of output values per sample (default 784). Despite
            the name, this is the flattened image size.
        img_shape: ``(C, H, W)`` shape that ``forward`` reshapes each sample to.
            Defaults to ``(1, 28, 28)``. Its product must equal ``num_classes``.

    Raises:
        ValueError: if ``num_classes`` does not match ``img_shape``. Without this
            check the mismatch would only surface later as an opaque ``view``
            error inside ``forward``.
    """

    def __init__(self, input_size=100, num_classes=784, img_shape=None):
        super(Generator, self).__init__()
        img_shape = (1, 28, 28) if img_shape is None else tuple(img_shape)
        if int(np.prod(img_shape)) != num_classes:
            raise ValueError(
                "num_classes={} does not match img_shape={} (product {})".format(
                    num_classes, img_shape, int(np.prod(img_shape))))
        self.img_shape = img_shape
        # Layer order is kept exactly so state_dict keys stay "layers.N.*".
        self.layers = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, num_classes),
            nn.Tanh()
        )

    def forward(self, x):
        """Map a ``(N, input_size)`` noise batch to ``(N, *img_shape)`` images."""
        x = self.layers(x)
        # Un-flatten the final layer's output into a batch of images.
        return x.view(x.size(0), *self.img_shape)


In [None]:
# Hyper-parameters (tune here).
seed = 42              # seeds torch's global RNG: weight init, shuffling, noise
n_noise = 100          # length of the generator's input noise vector
batch_size = 64
lr = 0.0002
betas = (0.5, 0.999)   # Adam betas shared by both optimizers
max_epoch = 50
step = 0               # global optimisation-step counter used by the training loop
n_critic = 1           # the generator is updated once every n_critic D steps

torch.manual_seed(seed)

discriminator = Discriminator().to(device)
# Tie the generator's input width to n_noise so the two cannot drift apart.
generator = Generator(input_size=n_noise).to(device)

# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1], which matches
# the Tanh output range of the generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

mnist = datasets.MNIST(root='../data/', train=True, transform=transform, download=True)

# drop_last=True keeps every batch at exactly batch_size rows, which the
# fixed-size label tensors in the training loop rely on.
data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True, drop_last=True)

loss_fn = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=betas)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=betas)


In [None]:
# BCELoss targets: 1 = "real", 0 = "fake". A fixed batch_size of rows is safe
# because the DataLoader is built with drop_last=True.
d_labels = torch.ones(batch_size, 1).to(device)
d_fakes = torch.zeros(batch_size, 1).to(device)

log_every = 1000   # steps between loss logs / sample snapshots
samples = {}       # step -> sample grid (NumPy array) captured during training

# Training loop
for epoch in range(max_epoch):
    for images, _ in data_loader:
        # ---- Discriminator update: push D(real) -> 1 and D(G(z)) -> 0 ----
        real_images = images.to(device)
        real_outputs = discriminator(real_images)
        d_real_loss = loss_fn(real_outputs, d_labels)

        fake_noise = torch.randn(batch_size, n_noise).to(device)
        fake_images = generator(fake_noise)
        # detach(): the discriminator loss must not back-propagate into G.
        fake_outputs = discriminator(fake_images.detach())
        d_fake_loss = loss_fn(fake_outputs, d_fakes)

        d_loss = d_real_loss + d_fake_loss

        discriminator.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator update: push D(G(z)) -> 1 ----
        if step % n_critic == 0:
            # Reuse fake_images. Its graph through G is intact, because D's backward
            # only saw the detached copy, and G's weights have not changed since.
            # A second generator(fake_noise) pass would be redundant and would
            # update the BatchNorm running stats twice on the same batch.
            fake_outputs = discriminator(fake_images)
            g_loss = loss_fn(fake_outputs, d_labels)

            generator.zero_grad()
            g_loss.backward()
            g_optimizer.step()

        if step % log_every == 0:
            print('{} | epoch {}/{} | step {} | D loss {:.4f} | G loss {:.4f}'.format(
                datetime.datetime.now().strftime('%H:%M:%S'), epoch, max_epoch, step,
                d_loss.item(), g_loss.item()))
            generator.eval()
            samples[step] = get_sample_image(generator, n_noise)
            generator.train()
        step += 1


In [None]:
# Inference mode (BatchNorm switches to its running statistics), then show a
# grid of samples from the trained generator.
generator.eval()
imshow(get_sample_image(generator, n_noise), cmap='gray')

# Persist both networks' weights next to the notebook.
for net, filename in [(discriminator, 'discriminator.pth'), (generator, 'generator.pth')]:
    torch.save(net.state_dict(), filename)
