Jihene commited on
Commit
6c6a596
·
1 Parent(s): d99b2ad

Create cgan.py

Browse files
Files changed (1) hide show
  1. cgan.py +206 -0
cgan.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "original code: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/cgan/cgan.py"
2
+
3
+ import argparse
4
+ import os
5
+ import numpy as np
6
+ import math
7
+
8
+ import torchvision.transforms as transforms
9
+ from torchvision.utils import save_image
10
+
11
+ from torch.utils.data import DataLoader
12
+ from torchvision import datasets
13
+ from torch.autograd import Variable
14
+
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torch
18
+
19
+
20
+
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("--n_epochs", type=int, default=10, help="number of epochs of training")
23
+ parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
24
+ parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
25
+ parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
26
+ parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
27
+ parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
28
+ parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
29
+ parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
30
+ parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
31
+ parser.add_argument("--channels", type=int, default=1, help="number of image channels")
32
+ parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
33
+ opt = parser.parse_args(args=[])
34
+ print(opt)
35
+
36
+ img_shape = (opt.channels, opt.img_size, opt.img_size)
37
+
38
+ cuda = True if torch.cuda.is_available() else False
39
+
40
+
41
+ class Generator(nn.Module):
42
+ def __init__(self):
43
+ super(Generator, self).__init__()
44
+
45
+ self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
46
+
47
+ def block(in_feat, out_feat, normalize=True):
48
+ layers = [nn.Linear(in_feat, out_feat)]
49
+ if normalize:
50
+ layers.append(nn.BatchNorm1d(out_feat, 0.8))
51
+ layers.append(nn.LeakyReLU(0.2, inplace=True))
52
+ return layers
53
+
54
+ self.model = nn.Sequential(
55
+ *block(opt.latent_dim + opt.n_classes, 128, normalize=False),
56
+ *block(128, 256),
57
+ *block(256, 512),
58
+ *block(512, 1024),
59
+ nn.Linear(1024, int(np.prod(img_shape))),
60
+ nn.Tanh()
61
+ )
62
+
63
+ def forward(self, noise, labels):
64
+ # Concatenate label embedding and image to produce input
65
+ gen_input = torch.cat((self.label_emb(labels), noise), -1)
66
+ img = self.model(gen_input)
67
+ img = img.view(img.size(0), *img_shape)
68
+ return img
69
+
70
+
71
+ class Discriminator(nn.Module):
72
+ def __init__(self):
73
+ super(Discriminator, self).__init__()
74
+
75
+ self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
76
+
77
+ self.model = nn.Sequential(
78
+ nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
79
+ nn.LeakyReLU(0.2, inplace=True),
80
+ nn.Linear(512, 512),
81
+ nn.Dropout(0.4),
82
+ nn.LeakyReLU(0.2, inplace=True),
83
+ nn.Linear(512, 512),
84
+ nn.Dropout(0.4),
85
+ nn.LeakyReLU(0.2, inplace=True),
86
+ nn.Linear(512, 1),
87
+ )
88
+
89
+ def forward(self, img, labels):
90
+ # Concatenate label embedding and image to produce input
91
+ d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
92
+ validity = self.model(d_in)
93
+ return validity
94
+
95
+
96
+ # Loss functions
97
+ adversarial_loss = torch.nn.MSELoss()
98
+
99
+ # Initialize generator and discriminator
100
+ generator = Generator()
101
+ discriminator = Discriminator()
102
+
103
+ if cuda:
104
+ generator.cuda()
105
+ discriminator.cuda()
106
+ adversarial_loss.cuda()
107
+
108
+ # Configure data loader
109
+ os.makedirs("../../data/mnist", exist_ok=True)
110
+ dataloader = torch.utils.data.DataLoader(
111
+ datasets.MNIST(
112
+ "../../data/mnist",
113
+ train=True,
114
+ download=True,
115
+ transform=transforms.Compose(
116
+ [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
117
+ ),
118
+ ),
119
+ batch_size=opt.batch_size,
120
+ shuffle=True,
121
+ )
122
+
123
+ # Optimizers
124
+ optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
125
+ optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
126
+
127
+ FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
128
+ LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
129
+
130
+ os.makedirs("images", exist_ok=True)
131
+ def sample_image(n_row, batches_done):
132
+ """Saves a grid of generated digits ranging from 0 to n_classes"""
133
+ # Sample noise
134
+ z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
135
+ # Get labels ranging from 0 to n_classes for n rows
136
+ labels = np.array([num for _ in range(n_row) for num in range(n_row)])
137
+ labels = Variable(LongTensor(labels))
138
+ gen_imgs = generator(z, labels)
139
+ save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
140
+
141
+
142
+ # ----------
143
+ # Training
144
+ # ----------
145
+
146
+ for epoch in range(opt.n_epochs):
147
+ for i, (imgs, labels) in enumerate(dataloader):
148
+
149
+ batch_size = imgs.shape[0]
150
+
151
+ # Adversarial ground truths
152
+ valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
153
+ fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
154
+
155
+ # Configure input
156
+ real_imgs = Variable(imgs.type(FloatTensor))
157
+ labels = Variable(labels.type(LongTensor))
158
+
159
+ # -----------------
160
+ # Train Generator
161
+ # -----------------
162
+
163
+ optimizer_G.zero_grad()
164
+
165
+ # Sample noise and labels as generator input
166
+ z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
167
+ gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))
168
+
169
+ # Generate a batch of images
170
+ gen_imgs = generator(z, gen_labels)
171
+
172
+ # Loss measures generator's ability to fool the discriminator
173
+ validity = discriminator(gen_imgs, gen_labels)
174
+ g_loss = adversarial_loss(validity, valid)
175
+
176
+ g_loss.backward()
177
+ optimizer_G.step()
178
+
179
+ # ---------------------
180
+ # Train Discriminator
181
+ # ---------------------
182
+
183
+ optimizer_D.zero_grad()
184
+
185
+ # Loss for real images
186
+ validity_real = discriminator(real_imgs, labels)
187
+ d_real_loss = adversarial_loss(validity_real, valid)
188
+
189
+ # Loss for fake images
190
+ validity_fake = discriminator(gen_imgs.detach(), gen_labels)
191
+ d_fake_loss = adversarial_loss(validity_fake, fake)
192
+
193
+ # Total discriminator loss
194
+ d_loss = (d_real_loss + d_fake_loss) / 2
195
+
196
+ d_loss.backward()
197
+ optimizer_D.step()
198
+
199
+ print(
200
+ "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
201
+ % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
202
+ )
203
+
204
+ batches_done = epoch * len(dataloader) + i
205
+ if batches_done % opt.sample_interval == 0:
206
+ sample_image(n_row=10, batches_done=batches_done)