Truptidand commited on
Commit
f13cdf6
·
1 Parent(s): 8f3c6f4

Delete GAN.py

Browse files
Files changed (1) hide show
  1. GAN.py +0 -870
GAN.py DELETED
@@ -1,870 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- # In[1]:
5
-
6
-
7
- import numpy as np
8
- import pandas as pd
9
- import matplotlib.pyplot as plt
10
- import seaborn as sns
11
- import os
12
- import pickle
13
- import time
14
- import random
15
-
16
-
17
- # In[8]:
18
-
19
-
20
- import PIL
21
- from PIL import Image
22
- import keras.backend as K
23
- import tensorflow as tf
24
- from tensorflow import keras
25
- from keras.optimizers import Adam
26
- from keras.models import Sequential
27
- from keras import layers,Model,Input
28
- from keras.layers import Lambda,Reshape,UpSampling2D,ReLU,add,ZeroPadding2D
29
- from keras.layers import Activation,BatchNormalization,Concatenate,concatenate
30
- from keras.layers import Dense,Conv2D,Flatten,Dropout,LeakyReLU
31
- from keras.preprocessing.image import ImageDataGenerator
32
-
33
-
34
- # ### Conditioning Augmentation Network
35
-
36
- # In[3]:
37
-
38
-
39
- # conditioned by the text.
40
- def conditioning_augmentation(x):
41
- """The mean_logsigma passed as argument is converted into the text conditioning variable.
42
-
43
- Args:
44
- x: The output of the text embedding passed through a FC layer with LeakyReLU non-linearity.
45
-
46
- Returns:
47
- c: The text conditioning variable after computation.
48
- """
49
- mean = x[:, :128]
50
- log_sigma = x[:, 128:]
51
-
52
- stddev = tf.math.exp(log_sigma)
53
- epsilon = K.random_normal(shape=K.constant((mean.shape[1], ), dtype='int32'))
54
- c = mean + stddev * epsilon
55
- return c
56
-
57
- def build_ca_network():
58
- """Builds the conditioning augmentation network.
59
- """
60
- input_layer1 = Input(shape=(1024,)) #size of the vocabulary in the text data
61
- mls = Dense(256)(input_layer1)
62
- mls = LeakyReLU(alpha=0.2)(mls)
63
- ca = Lambda(conditioning_augmentation)(mls)
64
- return Model(inputs=[input_layer1], outputs=[ca])
65
-
66
-
67
- # ### Stage 1 Generator Network
68
-
69
- # In[4]:
70
-
71
-
72
- def UpSamplingBlock(x, num_kernels):
73
- """An Upsample block with Upsampling2D, Conv2D, BatchNormalization and a ReLU activation.
74
-
75
- Args:
76
- x: The preceding layer as input.
77
- num_kernels: Number of kernels for the Conv2D layer.
78
-
79
- Returns:
80
- x: The final activation layer after the Upsampling block.
81
- """
82
- x = UpSampling2D(size=(2,2))(x)
83
- x = Conv2D(num_kernels, kernel_size=(3,3), padding='same', strides=1, use_bias=False,
84
- kernel_initializer='he_uniform')(x)
85
- x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x) #prevent from mode collapse
86
- x = ReLU()(x)
87
- return x
88
-
89
-
90
- def build_stage1_generator():
91
-
92
- input_layer1 = Input(shape=(1024,))
93
- ca = Dense(256)(input_layer1)
94
- ca = LeakyReLU(alpha=0.2)(ca)
95
-
96
- # Obtain the conditioned text
97
- c = Lambda(conditioning_augmentation)(ca)
98
-
99
- input_layer2 = Input(shape=(100,))
100
- concat = Concatenate(axis=1)([c, input_layer2])
101
-
102
- x = Dense(16384, use_bias=False)(concat)
103
- x = ReLU()(x)
104
- x = Reshape((4, 4, 1024), input_shape=(16384,))(x)
105
-
106
- x = UpSamplingBlock(x, 512)
107
- x = UpSamplingBlock(x, 256)
108
- x = UpSamplingBlock(x, 128)
109
- x = UpSamplingBlock(x, 64) # upsampled our image to 64*64*3
110
-
111
- x = Conv2D(3, kernel_size=3, padding='same', strides=1, use_bias=False,
112
- kernel_initializer='he_uniform')(x)
113
- x = Activation('tanh')(x)
114
-
115
- stage1_gen = Model(inputs=[input_layer1, input_layer2], outputs=[x, ca])
116
- return stage1_gen
117
-
118
-
119
- # In[5]:
120
-
121
-
122
- generator = build_stage1_generator()
123
- generator.summary()
124
-
125
-
126
- # ### Stage 1 Discriminator Network
127
-
128
- # In[9]:
129
-
130
-
131
- def ConvBlock(x, num_kernels, kernel_size=(4,4), strides=2, activation=True):
132
- """A ConvBlock with a Conv2D, BatchNormalization and LeakyReLU activation.
133
-
134
- Args:
135
- x: The preceding layer as input.
136
- num_kernels: Number of kernels for the Conv2D layer.
137
-
138
- Returns:
139
- x: The final activation layer after the ConvBlock block.
140
- """
141
- x = Conv2D(num_kernels, kernel_size=kernel_size, padding='same', strides=strides, use_bias=False,
142
- kernel_initializer='he_uniform')(x)
143
- x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
144
-
145
- if activation:
146
- x = LeakyReLU(alpha=0.2)(x)
147
- return x
148
-
149
-
150
- def build_embedding_compressor():
151
- """Build embedding compressor model
152
- """
153
- input_layer1 = Input(shape=(1024,))
154
- x = Dense(128)(input_layer1)
155
- x = ReLU()(x)
156
-
157
- model = Model(inputs=[input_layer1], outputs=[x])
158
- return model
159
-
160
- # the discriminator is fed with two inputs, the feature from Generator and the text embedding
161
- def build_stage1_discriminator():
162
- """Builds the Stage 1 Discriminator that uses the 64x64 resolution images from the generator
163
- and the compressed and spatially replicated embedding.
164
-
165
- Returns:
166
- Stage 1 Discriminator Model for StackGAN.
167
- """
168
- input_layer1 = Input(shape=(64, 64, 3))
169
-
170
- x = Conv2D(64, kernel_size=(4,4), strides=2, padding='same', use_bias=False,
171
- kernel_initializer='he_uniform')(input_layer1)
172
- x = LeakyReLU(alpha=0.2)(x)
173
-
174
- x = ConvBlock(x, 128)
175
- x = ConvBlock(x, 256)
176
- x = ConvBlock(x, 512)
177
-
178
- # Obtain the compressed and spatially replicated text embedding
179
- input_layer2 = Input(shape=(4, 4, 128)) #2nd input to discriminator, text embedding
180
- concat = concatenate([x, input_layer2])
181
-
182
- x1 = Conv2D(512, kernel_size=(1,1), padding='same', strides=1, use_bias=False,
183
- kernel_initializer='he_uniform')(concat)
184
- x1 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
185
- x1 = LeakyReLU(alpha=0.2)(x)
186
-
187
- # Flatten and add a FC layer to predict.
188
- x1 = Flatten()(x1)
189
- x1 = Dense(1)(x1)
190
- x1 = Activation('sigmoid')(x1)
191
-
192
- stage1_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x1])
193
- return stage1_dis
194
-
195
-
196
- # In[10]:
197
-
198
-
199
- discriminator = build_stage1_discriminator()
200
- discriminator.summary()
201
-
202
-
203
- # ### Stage 1 Adversarial Model (Building a GAN)
204
-
205
- # In[11]:
206
-
207
-
208
- # Building GAN with Generator and Discriminator
209
-
210
- def build_adversarial(generator_model, discriminator_model):
211
- """Stage 1 Adversarial model.
212
-
213
- Args:
214
- generator_model: Stage 1 Generator Model
215
- discriminator_model: Stage 1 Discriminator Model
216
-
217
- Returns:
218
- Adversarial Model.
219
- """
220
- input_layer1 = Input(shape=(1024,))
221
- input_layer2 = Input(shape=(100,))
222
- input_layer3 = Input(shape=(4, 4, 128))
223
-
224
- x, ca = generator_model([input_layer1, input_layer2]) #text,noise
225
-
226
- discriminator_model.trainable = False
227
-
228
- probabilities = discriminator_model([x, input_layer3])
229
- adversarial_model = Model(inputs=[input_layer1, input_layer2, input_layer3], outputs=[probabilities, ca])
230
- return adversarial_model
231
-
232
-
233
- # In[12]:
234
-
235
-
236
- ganstage1 = build_adversarial(generator, discriminator)
237
- ganstage1.summary()
238
-
239
-
240
- # ### Train Utilities
241
-
242
- # In[13]:
243
-
244
-
245
- def checkpoint_prefix():
246
- checkpoint_dir = './training_checkpoints'
247
- checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
248
-
249
- return checkpoint_prefix
250
-
251
- def adversarial_loss(y_true, y_pred):
252
- mean = y_pred[:, :128]
253
- ls = y_pred[:, 128:]
254
- loss = -ls + 0.5 * (-1 + tf.math.exp(2.0 * ls) + tf.math.square(mean))
255
- loss = K.mean(loss)
256
- return loss
257
-
258
- def normalize(input_image, real_image):
259
- input_image = (input_image / 127.5) - 1
260
- real_image = (real_image / 127.5) - 1
261
-
262
- return input_image, real_image
263
-
264
- def load_class_ids_filenames(class_id_path, filename_path):
265
- with open(class_id_path, 'rb') as file:
266
- class_id = pickle.load(file, encoding='latin1')
267
-
268
- with open(filename_path, 'rb') as file:
269
- filename = pickle.load(file, encoding='latin1')
270
-
271
- return class_id, filename
272
-
273
- def load_text_embeddings(text_embeddings):
274
- with open(text_embeddings, 'rb') as file:
275
- embeds = pickle.load(file, encoding='latin1')
276
- embeds = np.array(embeds)
277
-
278
- return embeds
279
-
280
- def load_bbox(data_path):
281
- bbox_path = data_path + '/bounding_boxes.txt'
282
- image_path = data_path + '/images.txt'
283
- bbox_df = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int)
284
- filename_df = pd.read_csv(image_path, delim_whitespace=True, header=None)
285
-
286
- filenames = filename_df[1].tolist()
287
- bbox_dict = {i[:-4]:[] for i in filenames[:2]}
288
-
289
- for i in range(0, len(filenames)):
290
- bbox = bbox_df.iloc[i][1:].tolist()
291
- dict_key = filenames[i][:-4]
292
- bbox_dict[dict_key] = bbox
293
-
294
- return bbox_dict
295
-
296
- def load_images(image_path, bounding_box, size):
297
- """Crops the image to the bounding box and then resizes it.
298
- """
299
- image = Image.open(image_path).convert('RGB')
300
- w, h = image.size
301
- if bounding_box is not None:
302
- r = int(np.maximum(bounding_box[2], bounding_box[3]) * 0.75)
303
- c_x = int((bounding_box[0] + bounding_box[2]) / 2)
304
- c_y = int((bounding_box[1] + bounding_box[3]) / 2)
305
- y1 = np.maximum(0, c_y - r)
306
- y2 = np.minimum(h, c_y + r)
307
- x1 = np.maximum(0, c_x - r)
308
- x2 = np.minimum(w, c_x + r)
309
- image = image.crop([x1, y1, x2, y2])
310
-
311
- image = image.resize(size, PIL.Image.BILINEAR)
312
- return image
313
-
314
- def load_data(filename_path, class_id_path, dataset_path, embeddings_path, size):
315
- """Loads the Dataset.
316
- """
317
- data_dir = "D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/birds"
318
- train_dir = data_dir + "/train"
319
- test_dir = data_dir + "/test"
320
- embeddings_path_train = train_dir + "/char-CNN-RNN-embeddings.pickle"
321
- embeddings_path_test = test_dir + "/char-CNN-RNN-embeddings.pickle"
322
- filename_path_train = train_dir + "/filenames.pickle"
323
- filename_path_test = test_dir + "/filenames.pickle"
324
- class_id_path_train = train_dir + "/class_info.pickle"
325
- class_id_path_test = test_dir + "/class_info.pickle"
326
- dataset_path = "D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/CUB_200_2011"
327
- class_id, filenames = load_class_ids_filenames(class_id_path, filename_path)
328
- embeddings = load_text_embeddings(embeddings_path)
329
- bbox_dict = load_bbox(dataset_path)
330
-
331
- x, y, embeds = [], [], []
332
-
333
- for i, filename in enumerate(filenames):
334
- bbox = bbox_dict[filename]
335
-
336
- try:
337
- image_path = f'{dataset_path}/images/{filename}.jpg'
338
- image = load_images(image_path, bbox, size)
339
- e = embeddings[i, :, :]
340
- embed_index = np.random.randint(0, e.shape[0] - 1)
341
- embed = e[embed_index, :]
342
-
343
- x.append(np.array(image))
344
- y.append(class_id[i])
345
- embeds.append(embed)
346
-
347
- except Exception as e:
348
- print(f'{e}')
349
-
350
- x = np.array(x)
351
- y = np.array(y)
352
- embeds = np.array(embeds)
353
-
354
- return x, y, embeds
355
-
356
- def save_image(file, save_path):
357
- """Saves the image at the specified file path.
358
- """
359
- image = plt.figure()
360
- ax = image.add_subplot(1,1,1)
361
- ax.imshow(file)
362
- ax.axis("off")
363
- plt.savefig(save_path)
364
-
365
-
366
- # In[28]:
367
-
368
-
369
- ############################################################
370
- # StackGAN class
371
- ############################################################
372
-
373
- class StackGanStage1(object):
374
- """StackGAN Stage 1 class."""
375
-
376
- data_dir = "D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/birds"
377
- train_dir = data_dir + "/train"
378
- test_dir = data_dir + "/test"
379
- embeddings_path_train = train_dir + "/char-CNN-RNN-embeddings.pickle"
380
- embeddings_path_test = test_dir + "/char-CNN-RNN-embeddings.pickle"
381
- filename_path_train = train_dir + "/filenames.pickle"
382
- filename_path_test = test_dir + "/filenames.pickle"
383
- class_id_path_train = train_dir + "/class_info.pickle"
384
- class_id_path_test = test_dir + "/class_info.pickle"
385
- dataset_path = "D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/CUB_200_2011"
386
- def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage1_generator_lr=0.0002, stage1_discriminator_lr=0.0002):
387
- self.epochs = epochs
388
- self.z_dim = z_dim
389
- self.enable_function = enable_function
390
- self.stage1_generator_lr = stage1_generator_lr
391
- self.stage1_discriminator_lr = stage1_discriminator_lr
392
- self.image_size = 64
393
- self.conditioning_dim = 128
394
- self.batch_size = batch_size
395
-
396
- self.stage1_generator_optimizer = Adam(lr=stage1_generator_lr, beta_1=0.5, beta_2=0.999)
397
- self.stage1_discriminator_optimizer = Adam(lr=stage1_discriminator_lr, beta_1=0.5, beta_2=0.999)
398
-
399
- self.stage1_generator = build_stage1_generator()
400
- self.stage1_generator.compile(loss='mse', optimizer=self.stage1_generator_optimizer)
401
-
402
- self.stage1_discriminator = build_stage1_discriminator()
403
- self.stage1_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage1_discriminator_optimizer)
404
-
405
- self.ca_network = build_ca_network()
406
- self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')
407
-
408
- self.embedding_compressor = build_embedding_compressor()
409
- self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')
410
-
411
- self.stage1_adversarial = build_adversarial(self.stage1_generator, self.stage1_discriminator)
412
- self.stage1_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage1_generator_optimizer)
413
-
414
- self.checkpoint1 = tf.train.Checkpoint(
415
- generator_optimizer=self.stage1_generator_optimizer,
416
- discriminator_optimizer=self.stage1_discriminator_optimizer,
417
- generator=self.stage1_generator,
418
- discriminator=self.stage1_discriminator)
419
-
420
- def visualize_stage1(self):
421
- """Running Tensorboard visualizations.
422
- """
423
- tb = TensorBoard(log_dir="logs/".format(time.time()))
424
- tb.set_model(self.stage1_generator)
425
- tb.set_model(self.stage1_discriminator)
426
- tb.set_model(self.ca_network)
427
- tb.set_model(self.embedding_compressor)
428
-
429
- def train_stage1(self):
430
- """Trains the stage1 StackGAN.
431
- """
432
- x_train, y_train, train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,
433
- dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(64, 64))
434
-
435
- x_test, y_test, test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test,
436
- dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(64, 64))
437
-
438
- real = np.ones((self.batch_size, 1), dtype='float') * 0.9
439
- fake = np.zeros((self.batch_size, 1), dtype='float') * 0.1
440
-
441
- for epoch in range(self.epochs):
442
- print(f'Epoch: {epoch}')
443
-
444
- gen_loss = []
445
- dis_loss = []
446
-
447
- num_batches = int(x_train.shape[0] / self.batch_size)
448
-
449
- for i in range(num_batches):
450
-
451
- latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
452
- embedding_text = train_embeds[i * self.batch_size:(i + 1) * self.batch_size]
453
- compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)
454
- compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, 128))
455
- compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))
456
-
457
- image_batch = x_train[i * self.batch_size:(i+1) * self.batch_size]
458
- image_batch = (image_batch - 127.5) / 127.5
459
-
460
- gen_images, _ = self.stage1_generator.predict([embedding_text, latent_space])
461
-
462
- discriminator_loss = self.stage1_discriminator.train_on_batch([image_batch, compressed_embedding],
463
- np.reshape(real, (self.batch_size, 1)))
464
-
465
- discriminator_loss_gen = self.stage1_discriminator.train_on_batch([gen_images, compressed_embedding],
466
- np.reshape(fake, (self.batch_size, 1)))
467
-
468
- discriminator_loss_wrong = self.stage1_discriminator.train_on_batch([gen_images[: self.batch_size-1], compressed_embedding[1:]],
469
- np.reshape(fake[1:], (self.batch_size-1, 1)))
470
-
471
- # Discriminator loss
472
- d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_wrong))
473
- dis_loss.append(d_loss)
474
-
475
- print(f'Discriminator Loss: {d_loss}')
476
-
477
- # Generator loss
478
- g_loss = self.stage1_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],
479
- [K.ones((self.batch_size, 1)) * 0.9, K.ones((self.batch_size, 256)) * 0.9])
480
-
481
- print(f'Generator Loss: {g_loss}')
482
- gen_loss.append(g_loss)
483
-
484
- if epoch % 5 == 0:
485
- latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
486
- embedding_batch = test_embeds[0 : self.batch_size]
487
- gen_images, _ = self.stage1_generator.predict_on_batch([embedding_batch, latent_space])
488
-
489
- for i, image in enumerate(gen_images[:10]):
490
- save_image(image, f'test/gen_1_{epoch}_{i}')
491
-
492
- if epoch % 25 == 0:
493
- self.stage1_generator.save_weights('weights/stage1_gen.h5')
494
- self.stage1_discriminator.save_weights("weights/stage1_disc.h5")
495
- self.ca_network.save_weights('weights/stage1_ca.h5')
496
- self.embedding_compressor.save_weights('weights/stage1_embco.h5')
497
- self.stage1_adversarial.save_weights('weights/stage1_adv.h5')
498
-
499
- self.stage1_generator.save_weights('weights/stage1_gen.h5')
500
- self.stage1_discriminator.save_weights("weights/stage1_disc.h5")
501
-
502
-
503
- # In[ ]:
504
-
505
-
506
- stage1 = StackGanStage1()
507
- stage1.train_stage1()
508
-
509
-
510
- # ### Check test folder for gernerated images from Stage1 Generator
511
- # ### Let's Implement Stage 2 Generator
512
-
513
- # In[29]:
514
-
515
-
516
- ############################################################
517
- # Stage 2 Generator Network
518
- ############################################################
519
-
520
- def concat_along_dims(inputs):
521
- """Joins the conditioned text with the encoded image along the dimensions.
522
-
523
- Args:
524
- inputs: consisting of conditioned text and encoded images as [c,x].
525
-
526
- Returns:
527
- Joint block along the dimensions.
528
- """
529
- c = inputs[0]
530
- x = inputs[1]
531
-
532
- c = K.expand_dims(c, axis=1)
533
- c = K.expand_dims(c, axis=1)
534
- c = K.tile(c, [1, 16, 16, 1])
535
- return K.concatenate([c, x], axis = 3)
536
-
537
- def residual_block(input):
538
- """Residual block with plain identity connections.
539
-
540
- Args:
541
- inputs: input layer or an encoded layer
542
-
543
- Returns:
544
- Layer with computed identity mapping.
545
- """
546
- x = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,
547
- kernel_initializer='he_uniform')(input)
548
- x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
549
- x = ReLU()(x)
550
-
551
- x = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,
552
- kernel_initializer='he_uniform')(x)
553
- x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
554
-
555
- x = add([x, input])
556
- x = ReLU()(x)
557
-
558
- return x
559
-
560
- def build_stage2_generator():
561
- """Build the Stage 2 Generator Network using the conditioning text and images from stage 1.
562
-
563
- Returns:
564
- Stage 2 Generator Model for StackGAN.
565
- """
566
- input_layer1 = Input(shape=(1024,))
567
- input_images = Input(shape=(64, 64, 3))
568
-
569
- # Conditioning Augmentation
570
- ca = Dense(256)(input_layer1)
571
- mls = LeakyReLU(alpha=0.2)(ca)
572
- c = Lambda(conditioning_augmentation)(mls)
573
-
574
- # Downsampling block
575
- x = ZeroPadding2D(padding=(1,1))(input_images)
576
- x = Conv2D(128, kernel_size=(3,3), strides=1, use_bias=False,
577
- kernel_initializer='he_uniform')(x)
578
- x = ReLU()(x)
579
-
580
- x = ZeroPadding2D(padding=(1,1))(x)
581
- x = Conv2D(256, kernel_size=(4,4), strides=2, use_bias=False,
582
- kernel_initializer='he_uniform')(x)
583
- x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
584
- x = ReLU()(x)
585
-
586
- x = ZeroPadding2D(padding=(1,1))(x)
587
- x = Conv2D(512, kernel_size=(4,4), strides=2, use_bias=False,
588
- kernel_initializer='he_uniform')(x)
589
- x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
590
- x = ReLU()(x)
591
-
592
- # Concatenate text conditioning block with the encoded image
593
- concat = concat_along_dims([c, x])
594
-
595
- # Residual Blocks
596
- x = ZeroPadding2D(padding=(1,1))(concat)
597
- x = Conv2D(512, kernel_size=(3,3), use_bias=False, kernel_initializer='he_uniform')(x)
598
- x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
599
- x = ReLU()(x)
600
-
601
- x = residual_block(x)
602
- x = residual_block(x)
603
- x = residual_block(x)
604
- x = residual_block(x)
605
-
606
- # Upsampling Blocks
607
- x = UpSamplingBlock(x, 512)
608
- x = UpSamplingBlock(x, 256)
609
- x = UpSamplingBlock(x, 128)
610
- x = UpSamplingBlock(x, 64)
611
-
612
- x = Conv2D(3, kernel_size=(3,3), padding='same', use_bias=False, kernel_initializer='he_uniform')(x)
613
- x = Activation('tanh')(x)
614
-
615
- stage2_gen = Model(inputs=[input_layer1, input_images], outputs=[x, mls])
616
- return stage2_gen
617
-
618
-
619
- # In[30]:
620
-
621
-
622
- generator_stage2 = build_stage2_generator()
623
- generator_stage2.summary()
624
-
625
-
626
- # In[31]:
627
-
628
-
629
- ############################################################
630
- # Stage 2 Discriminator Network
631
- ############################################################
632
-
633
- def build_stage2_discriminator():
634
- """Builds the Stage 2 Discriminator that uses the 256x256 resolution images from the generator
635
- and the compressed and spatially replicated embeddings.
636
-
637
- Returns:
638
- Stage 2 Discriminator Model for StackGAN.
639
- """
640
- input_layer1 = Input(shape=(256, 256, 3))
641
-
642
- x = Conv2D(64, kernel_size=(4,4), padding='same', strides=2, use_bias=False,
643
- kernel_initializer='he_uniform')(input_layer1)
644
- x = LeakyReLU(alpha=0.2)(x)
645
-
646
- x = ConvBlock(x, 128)
647
- x = ConvBlock(x, 256)
648
- x = ConvBlock(x, 512)
649
- x = ConvBlock(x, 1024)
650
- x = ConvBlock(x, 2048)
651
- x = ConvBlock(x, 1024, (1,1), 1)
652
- x = ConvBlock(x, 512, (1,1), 1, False)
653
-
654
- x1 = ConvBlock(x, 128, (1,1), 1)
655
- x1 = ConvBlock(x1, 128, (3,3), 1)
656
- x1 = ConvBlock(x1, 512, (3,3), 1, False)
657
-
658
- x2 = add([x, x1])
659
- x2 = LeakyReLU(alpha=0.2)(x2)
660
-
661
- # Concatenate compressed and spatially replicated embedding
662
- input_layer2 = Input(shape=(4, 4, 128))
663
- concat = concatenate([x2, input_layer2])
664
-
665
- x3 = Conv2D(512, kernel_size=(1,1), strides=1, padding='same', kernel_initializer='he_uniform')(concat)
666
- x3 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x3)
667
- x3 = LeakyReLU(alpha=0.2)(x3)
668
-
669
- # Flatten and add a FC layer
670
- x3 = Flatten()(x3)
671
- x3 = Dense(1)(x3)
672
- x3 = Activation('sigmoid')(x3)
673
-
674
- stage2_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x3])
675
- return stage2_dis
676
-
677
-
678
- # In[32]:
679
-
680
-
681
- discriminator_stage2 = build_stage2_discriminator()
682
- discriminator_stage2.summary()
683
-
684
-
685
- # In[33]:
686
-
687
-
688
- ############################################################
689
- # Stage 2 Adversarial Model
690
- ############################################################
691
-
692
- def stage2_adversarial_network(stage2_disc, stage2_gen, stage1_gen):
693
- """Stage 2 Adversarial Network.
694
-
695
- Args:
696
- stage2_disc: Stage 2 Discriminator Model.
697
- stage2_gen: Stage 2 Generator Model.
698
- stage1_gen: Stage 1 Generator Model.
699
-
700
- Returns:
701
- Stage 2 Adversarial network.
702
- """
703
- conditioned_embedding = Input(shape=(1024, ))
704
- latent_space = Input(shape=(100, ))
705
- compressed_replicated = Input(shape=(4, 4, 128))
706
-
707
- #the discriminator is trained separately and stage1_gen already trained, and this is the reason why we freeze its layers by setting the property trainable=false
708
- input_images, ca = stage1_gen([conditioned_embedding, latent_space])
709
- stage2_disc.trainable = False
710
- stage1_gen.trainable = False
711
-
712
- images, ca2 = stage2_gen([conditioned_embedding, input_images])
713
- probability = stage2_disc([images, compressed_replicated])
714
-
715
- return Model(inputs=[conditioned_embedding, latent_space, compressed_replicated],
716
- outputs=[probability, ca2])
717
-
718
-
719
- # In[34]:
720
-
721
-
722
- adversarial_stage2 = stage2_adversarial_network(discriminator_stage2, generator_stage2, generator)
723
- adversarial_stage2.summary()
724
-
725
-
726
- # In[35]:
727
-
728
-
729
- class StackGanStage2(object):
730
- """StackGAN Stage 2 class.
731
-
732
- Args:
733
- epochs: Number of epochs
734
- z_dim: Latent space dimensions
735
- batch_size: Batch Size
736
- enable_function: If True, training function is decorated with tf.function
737
- stage2_generator_lr: Learning rate for stage 2 generator
738
- stage2_discriminator_lr: Learning rate for stage 2 discriminator
739
- """
740
- def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage2_generator_lr=0.0002, stage2_discriminator_lr=0.0002):
741
- self.epochs = epochs
742
- self.z_dim = z_dim
743
- self.enable_function = enable_function
744
- self.stage1_generator_lr = stage2_generator_lr
745
- self.stage1_discriminator_lr = stage2_discriminator_lr
746
- self.low_image_size = 64
747
- self.high_image_size = 256
748
- self.conditioning_dim = 128
749
- self.batch_size = batch_size
750
- self.stage2_generator_optimizer = Adam(lr=stage2_generator_lr, beta_1=0.5, beta_2=0.999)
751
- self.stage2_discriminator_optimizer = Adam(lr=stage2_discriminator_lr, beta_1=0.5, beta_2=0.999)
752
- self.stage1_generator = build_stage1_generator()
753
- self.stage1_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)
754
- self.stage1_generator.load_weights('weights/stage1_gen.h5')
755
- self.stage2_generator = build_stage2_generator()
756
- self.stage2_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)
757
-
758
- self.stage2_discriminator = build_stage2_discriminator()
759
- self.stage2_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage2_discriminator_optimizer)
760
-
761
- self.ca_network = build_ca_network()
762
- self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')
763
-
764
- self.embedding_compressor = build_embedding_compressor()
765
- self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')
766
-
767
- self.stage2_adversarial = stage2_adversarial_network(self.stage2_discriminator, self.stage2_generator, self.stage1_generator)
768
- self.stage2_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage2_generator_optimizer)
769
-
770
- self.checkpoint2 = tf.train.Checkpoint(
771
- generator_optimizer=self.stage2_generator_optimizer,
772
- discriminator_optimizer=self.stage2_discriminator_optimizer,
773
- generator=self.stage2_generator,
774
- discriminator=self.stage2_discriminator,
775
- generator1=self.stage1_generator)
776
-
777
- def visualize_stage2(self):
778
- """Running Tensorboard visualizations.
779
- """
780
- tb = TensorBoard(log_dir="logs/".format(time.time()))
781
- tb.set_model(self.stage2_generator)
782
- tb.set_model(self.stage2_discriminator)
783
-
784
- def train_stage2(self):
785
- """Trains Stage 2 StackGAN.
786
- """
787
- x_high_train, y_high_train, high_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,
788
- dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(256, 256))
789
-
790
- x_high_test, y_high_test, high_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test,
791
- dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(256, 256))
792
-
793
- x_low_train, y_low_train, low_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,
794
- dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(64, 64))
795
-
796
- x_low_test, y_low_test, low_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test,
797
- dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(64, 64))
798
-
799
- real = np.ones((self.batch_size, 1), dtype='float') * 0.9
800
- fake = np.zeros((self.batch_size, 1), dtype='float') * 0.1
801
-
802
- for epoch in range(self.epochs):
803
- print(f'Epoch: {epoch}')
804
-
805
- gen_loss = []
806
- disc_loss = []
807
-
808
- num_batches = int(x_high_train.shape[0] / self.batch_size)
809
-
810
- for i in range(num_batches):
811
-
812
- latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
813
- embedding_text = high_train_embeds[i * self.batch_size:(i + 1) * self.batch_size]
814
- compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)
815
- compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, self.conditioning_dim))
816
- compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))
817
-
818
- image_batch = x_high_train[i * self.batch_size:(i+1) * self.batch_size]
819
- image_batch = (image_batch - 127.5) / 127.5
820
-
821
- low_res_fakes, _ = self.stage1_generator.predict([embedding_text, latent_space], verbose=3)
822
- high_res_fakes, _ = self.stage2_generator.predict([embedding_text, low_res_fakes], verbose=3)
823
-
824
- discriminator_loss = self.stage2_discriminator.train_on_batch([image_batch, compressed_embedding],
825
- np.reshape(real, (self.batch_size, 1)))
826
-
827
- discriminator_loss_gen = self.stage2_discriminator.train_on_batch([high_res_fakes, compressed_embedding],
828
- np.reshape(fake, (self.batch_size, 1)))
829
-
830
- discriminator_loss_fake = self.stage2_discriminator.train_on_batch([image_batch[:(self.batch_size-1)], compressed_embedding[1:]],
831
- np.reshape(fake[1:], (self.batch_size - 1, 1)))
832
-
833
- d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_fake))
834
- disc_loss.append(d_loss)
835
-
836
- print(f'Discriminator Loss: {d_loss}')
837
-
838
- g_loss = self.stage2_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],
839
- [K.ones((self.batch_size, 1)) * 0.9, K.ones((self.batch_size, 256)) * 0.9])
840
- gen_loss.append(g_loss)
841
-
842
- print(f'Generator Loss: {g_loss}')
843
-
844
- if epoch % 5 == 0:
845
- latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
846
- embedding_batch = high_test_embeds[0 : self.batch_size]
847
-
848
- low_fake_images, _ = self.stage1_generator.predict([embedding_batch, latent_space], verbose=3)
849
- high_fake_images, _ = self.stage2_generator.predict([embedding_batch, low_fake_images], verbose=3)
850
-
851
- for i, image in enumerate(high_fake_images[:10]):
852
- save_image(image, f'results_stage2/gen_{epoch}_{i}.png')
853
-
854
- if epoch % 10 == 0:
855
- self.stage2_generator.save_weights('weights/stage2_gen.h5')
856
- self.stage2_discriminator.save_weights("weights/stage2_disc.h5")
857
- self.ca_network.save_weights('weights/stage2_ca.h5')
858
- self.embedding_compressor.save_weights('weights/stage2_embco.h5')
859
- self.stage2_adversarial.save_weights('weights/stage2_adv.h5')
860
-
861
- self.stage2_generator.save_weights('weights/stage2_gen.h5')
862
- self.stage2_discriminator.save_weights("weights/stage2_disc.h5")
863
-
864
-
865
- # In[ ]:
866
-
867
-
868
- stage2 = StackGanStage2()
869
- stage2.train_stage2()
870
-