Create cgan.py
Browse files
cgan.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"original code: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/cgan/cgan.py"
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
from torchvision.utils import save_image
|
10 |
+
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from torchvision import datasets
|
13 |
+
from torch.autograd import Variable
|
14 |
+
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torch
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
# Hyper-parameters and run configuration, declared as argparse options.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=10, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# b2 is Adam's second-moment decay rate; the help text previously repeated the b1 wording.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")

# An explicit empty argument list is parsed, so the defaults above are always
# used and the process's real command line is never read.
opt = parser.parse_args(args=[])
print(opt)

# (channels, height, width) of the images handled by both networks.
img_shape = (opt.channels, opt.img_size, opt.img_size)

# torch.cuda.is_available() already returns a bool.
cuda = torch.cuda.is_available()
|
39 |
+
|
40 |
+
|
41 |
+
class Generator(nn.Module):
    """Conditional generator: turns a noise vector and a class label into an image.

    The label is looked up in a learned embedding, concatenated with the
    noise, and passed through a stack of fully connected layers that ends in
    ``Tanh``. The flat output is reshaped to ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        # One learned vector of length n_classes per class label.
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        # Input width followed by the hidden widths of the fully connected stack.
        widths = (opt.latent_dim + opt.n_classes, 128, 256, 512, 1024)

        stack = []
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Linear(n_in, n_out))
            if idx > 0:  # the first hidden layer is not normalized
                # 0.8 lands on BatchNorm1d's second parameter, which is `eps`.
                # NOTE(review): possibly intended as momentum -- confirm before
                # changing, since it alters the trained model's behaviour.
                stack.append(nn.BatchNorm1d(n_out, eps=0.8))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Project to one value per pixel and squash into [-1, 1].
        stack.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        stack.append(nn.Tanh())

        self.model = nn.Sequential(*stack)

    def forward(self, noise, labels):
        """Generate a batch of images.

        Args:
            noise: Latent vectors; the last dimension must be ``opt.latent_dim``
                so that, together with the label embedding, it matches the
                first linear layer.
            labels: Integer class indices fed to the label embedding.

        Returns:
            Tensor reshaped to ``(batch, *img_shape)``.
        """
        # Condition the noise on the class by prepending the label embedding
        conditioned = torch.cat((self.label_emb(labels), noise), -1)
        flat = self.model(conditioned)
        return flat.view(flat.size(0), *img_shape)
|
69 |
+
|
70 |
+
|
71 |
+
class Discriminator(nn.Module):
    """Conditional discriminator: scores an image together with a class label.

    The image is flattened, concatenated with a learned label embedding and
    passed through fully connected layers. The last layer is a plain
    ``Linear(512, 1)``, so the returned score has no activation applied.
    """

    def __init__(self):
        super().__init__()

        # One learned vector of length n_classes per class label.
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        in_features = opt.n_classes + int(np.prod(img_shape))
        hidden = 512

        stack = [nn.Linear(in_features, hidden), nn.LeakyReLU(0.2, inplace=True)]
        # Two identical hidden stages with dropout
        for _ in range(2):
            stack.append(nn.Linear(hidden, hidden))
            stack.append(nn.Dropout(0.4))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        stack.append(nn.Linear(hidden, 1))

        self.model = nn.Sequential(*stack)

    def forward(self, img, labels):
        """Score a batch of images.

        Args:
            img: Image batch; everything after the first dimension is flattened.
            labels: Integer class indices fed to the label embedding.

        Returns:
            Tensor with one score per sample in its last dimension.
        """
        # Flatten the image and append the label embedding
        flat_img = img.view(img.size(0), -1)
        d_in = torch.cat((flat_img, self.label_embedding(labels)), -1)
        return self.model(d_in)
|
94 |
+
|
95 |
+
|
96 |
+
# Loss function: mean squared error between discriminator scores and targets
adversarial_loss = torch.nn.MSELoss()

# Build the two networks (generator first, then discriminator)
generator = Generator()
discriminator = Discriminator()

# Move the networks and the loss to the GPU when one is available
if cuda:
    for module in (generator, discriminator, adversarial_loss):
        module.cuda()

# MNIST training split: resized to img_size, converted to a tensor and
# normalized with mean 0.5 / std 0.5
mnist_root = "../../data/mnist"
os.makedirs(mnist_root, exist_ok=True)
mnist_transform = transforms.Compose(
    [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)
mnist_train = datasets.MNIST(mnist_root, train=True, download=True, transform=mnist_transform)
dataloader = DataLoader(mnist_train, batch_size=opt.batch_size, shuffle=True)

# One Adam optimizer per network, sharing the same hyper-parameters
adam_betas = (opt.b1, opt.b2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=adam_betas)

# Tensor constructors matching the selected device
if cuda:
    FloatTensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    FloatTensor, LongTensor = torch.FloatTensor, torch.LongTensor

# Output directory for the sampled image grids
os.makedirs("images", exist_ok=True)
|
131 |
+
def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of generated digits as a PNG.

    Each row of the grid is conditioned on the labels ``0 .. n_row - 1``, so
    ``n_row`` must not exceed the number of classes known to the generator's
    label embedding (``opt.n_classes``).

    Args:
        n_row: Side length of the grid and number of distinct labels drawn.
        batches_done: Batch counter; the file is written to
            ``images/<batches_done>.png``.
    """
    # Sample noise
    z = FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim)))
    # Labels 0 .. n_row-1, repeated once per grid row
    labels = LongTensor(np.tile(np.arange(n_row), n_row))
    # Sampling only needs the output values, so skip autograd bookkeeping.
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
|
140 |
+
|
141 |
+
|
142 |
+
# ----------
|
143 |
+
# Training
|
144 |
+
# ----------
|
145 |
+
|
146 |
+
for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Targets for the adversarial loss: 1 for "real", 0 for "generated"
        valid = FloatTensor(batch_size, 1).fill_(1.0)
        fake = FloatTensor(batch_size, 1).fill_(0.0)

        # Cast the batch to the tensor types selected for the device
        real_imgs = imgs.type(FloatTensor)
        labels = labels.type(LongTensor)

        # -----------------
        #  Generator step
        # -----------------

        optimizer_G.zero_grad()

        # Random noise and random target classes for the generator
        z = FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim)))
        gen_labels = LongTensor(np.random.randint(0, opt.n_classes, batch_size))
        gen_imgs = generator(z, gen_labels)

        # The generator is rewarded when the discriminator scores its output as real
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Discriminator step
        # ---------------------

        optimizer_D.zero_grad()

        # Real images should be scored as real ...
        d_real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
        # ... and generated ones as fake; detach() keeps this step from
        # back-propagating into the generator
        d_fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)

        # Average of the two discriminator terms
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        # Periodically write a grid of samples, keyed by the global batch count
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)
|