Truptidand commited on
Commit
e633c87
·
1 Parent(s): c180ca7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +870 -0
app.py ADDED
@@ -0,0 +1,870 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+ import os
12
+ import pickle
13
+ import time
14
+ import random
15
+
16
+
17
+ # In[8]:
18
+
19
+
20
+ import PIL
21
+ from PIL import Image
22
+ import keras.backend as K
23
+ import tensorflow as tf
24
+ from tensorflow import keras
25
+ from keras.optimizers import Adam
26
+ from keras.models import Sequential
27
+ from keras import layers,Model,Input
28
+ from keras.layers import Lambda,Reshape,UpSampling2D,ReLU,add,ZeroPadding2D
29
+ from keras.layers import Activation,BatchNormalization,Concatenate,concatenate
30
+ from keras.layers import Dense,Conv2D,Flatten,Dropout,LeakyReLU
31
+ from keras.preprocessing.image import ImageDataGenerator
32
+
33
+
34
+ # ### Conditioning Augmentation Network
35
+
36
+ # In[3]:
37
+
38
+
39
+ # conditioned by the text.
40
+ def conditioning_augmentation(x):
41
+ """The mean_logsigma passed as argument is converted into the text conditioning variable.
42
+
43
+ Args:
44
+ x: The output of the text embedding passed through a FC layer with LeakyReLU non-linearity.
45
+
46
+ Returns:
47
+ c: The text conditioning variable after computation.
48
+ """
49
+ mean = x[:, :128]
50
+ log_sigma = x[:, 128:]
51
+
52
+ stddev = tf.math.exp(log_sigma)
53
+ epsilon = K.random_normal(shape=K.constant((mean.shape[1], ), dtype='int32'))
54
+ c = mean + stddev * epsilon
55
+ return c
56
+
57
+ def build_ca_network():
58
+ """Builds the conditioning augmentation network.
59
+ """
60
+ input_layer1 = Input(shape=(1024,)) #size of the vocabulary in the text data
61
+ mls = Dense(256)(input_layer1)
62
+ mls = LeakyReLU(alpha=0.2)(mls)
63
+ ca = Lambda(conditioning_augmentation)(mls)
64
+ return Model(inputs=[input_layer1], outputs=[ca])
65
+
66
+
67
+ # ### Stage 1 Generator Network
68
+
69
+ # In[4]:
70
+
71
+
72
+ def UpSamplingBlock(x, num_kernels):
73
+ """An Upsample block with Upsampling2D, Conv2D, BatchNormalization and a ReLU activation.
74
+
75
+ Args:
76
+ x: The preceding layer as input.
77
+ num_kernels: Number of kernels for the Conv2D layer.
78
+
79
+ Returns:
80
+ x: The final activation layer after the Upsampling block.
81
+ """
82
+ x = UpSampling2D(size=(2,2))(x)
83
+ x = Conv2D(num_kernels, kernel_size=(3,3), padding='same', strides=1, use_bias=False,
84
+ kernel_initializer='he_uniform')(x)
85
+ x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x) #prevent from mode collapse
86
+ x = ReLU()(x)
87
+ return x
88
+
89
+
90
+ def build_stage1_generator():
91
+
92
+ input_layer1 = Input(shape=(1024,))
93
+ ca = Dense(256)(input_layer1)
94
+ ca = LeakyReLU(alpha=0.2)(ca)
95
+
96
+ # Obtain the conditioned text
97
+ c = Lambda(conditioning_augmentation)(ca)
98
+
99
+ input_layer2 = Input(shape=(100,))
100
+ concat = Concatenate(axis=1)([c, input_layer2])
101
+
102
+ x = Dense(16384, use_bias=False)(concat)
103
+ x = ReLU()(x)
104
+ x = Reshape((4, 4, 1024), input_shape=(16384,))(x)
105
+
106
+ x = UpSamplingBlock(x, 512)
107
+ x = UpSamplingBlock(x, 256)
108
+ x = UpSamplingBlock(x, 128)
109
+ x = UpSamplingBlock(x, 64) # upsampled our image to 64*64*3
110
+
111
+ x = Conv2D(3, kernel_size=3, padding='same', strides=1, use_bias=False,
112
+ kernel_initializer='he_uniform')(x)
113
+ x = Activation('tanh')(x)
114
+
115
+ stage1_gen = Model(inputs=[input_layer1, input_layer2], outputs=[x, ca])
116
+ return stage1_gen
117
+
118
+
119
+ # In[5]:
120
+
121
+
122
+ generator = build_stage1_generator()
123
+ generator.summary()
124
+
125
+
126
+ # ### Stage 1 Discriminator Network
127
+
128
+ # In[9]:
129
+
130
+
131
+ def ConvBlock(x, num_kernels, kernel_size=(4,4), strides=2, activation=True):
132
+ """A ConvBlock with a Conv2D, BatchNormalization and LeakyReLU activation.
133
+
134
+ Args:
135
+ x: The preceding layer as input.
136
+ num_kernels: Number of kernels for the Conv2D layer.
137
+
138
+ Returns:
139
+ x: The final activation layer after the ConvBlock block.
140
+ """
141
+ x = Conv2D(num_kernels, kernel_size=kernel_size, padding='same', strides=strides, use_bias=False,
142
+ kernel_initializer='he_uniform')(x)
143
+ x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
144
+
145
+ if activation:
146
+ x = LeakyReLU(alpha=0.2)(x)
147
+ return x
148
+
149
+
150
+ def build_embedding_compressor():
151
+ """Build embedding compressor model
152
+ """
153
+ input_layer1 = Input(shape=(1024,))
154
+ x = Dense(128)(input_layer1)
155
+ x = ReLU()(x)
156
+
157
+ model = Model(inputs=[input_layer1], outputs=[x])
158
+ return model
159
+
160
+ # the discriminator is fed with two inputs, the feature from Generator and the text embedding
161
+ def build_stage1_discriminator():
162
+ """Builds the Stage 1 Discriminator that uses the 64x64 resolution images from the generator
163
+ and the compressed and spatially replicated embedding.
164
+
165
+ Returns:
166
+ Stage 1 Discriminator Model for StackGAN.
167
+ """
168
+ input_layer1 = Input(shape=(64, 64, 3))
169
+
170
+ x = Conv2D(64, kernel_size=(4,4), strides=2, padding='same', use_bias=False,
171
+ kernel_initializer='he_uniform')(input_layer1)
172
+ x = LeakyReLU(alpha=0.2)(x)
173
+
174
+ x = ConvBlock(x, 128)
175
+ x = ConvBlock(x, 256)
176
+ x = ConvBlock(x, 512)
177
+
178
+ # Obtain the compressed and spatially replicated text embedding
179
+ input_layer2 = Input(shape=(4, 4, 128)) #2nd input to discriminator, text embedding
180
+ concat = concatenate([x, input_layer2])
181
+
182
+ x1 = Conv2D(512, kernel_size=(1,1), padding='same', strides=1, use_bias=False,
183
+ kernel_initializer='he_uniform')(concat)
184
+ x1 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
185
+ x1 = LeakyReLU(alpha=0.2)(x)
186
+
187
+ # Flatten and add a FC layer to predict.
188
+ x1 = Flatten()(x1)
189
+ x1 = Dense(1)(x1)
190
+ x1 = Activation('sigmoid')(x1)
191
+
192
+ stage1_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x1])
193
+ return stage1_dis
194
+
195
+
196
+ # In[10]:
197
+
198
+
199
+ discriminator = build_stage1_discriminator()
200
+ discriminator.summary()
201
+
202
+
203
+ # ### Stage 1 Adversarial Model (Building a GAN)
204
+
205
+ # In[11]:
206
+
207
+
208
+ # Building GAN with Generator and Discriminator
209
+
210
+ def build_adversarial(generator_model, discriminator_model):
211
+ """Stage 1 Adversarial model.
212
+
213
+ Args:
214
+ generator_model: Stage 1 Generator Model
215
+ discriminator_model: Stage 1 Discriminator Model
216
+
217
+ Returns:
218
+ Adversarial Model.
219
+ """
220
+ input_layer1 = Input(shape=(1024,))
221
+ input_layer2 = Input(shape=(100,))
222
+ input_layer3 = Input(shape=(4, 4, 128))
223
+
224
+ x, ca = generator_model([input_layer1, input_layer2]) #text,noise
225
+
226
+ discriminator_model.trainable = False
227
+
228
+ probabilities = discriminator_model([x, input_layer3])
229
+ adversarial_model = Model(inputs=[input_layer1, input_layer2, input_layer3], outputs=[probabilities, ca])
230
+ return adversarial_model
231
+
232
+
233
+ # In[12]:
234
+
235
+
236
+ ganstage1 = build_adversarial(generator, discriminator)
237
+ ganstage1.summary()
238
+
239
+
240
+ # ### Train Utilities
241
+
242
+ # In[13]:
243
+
244
+
245
+ def checkpoint_prefix():
246
+ checkpoint_dir = './training_checkpoints'
247
+ checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
248
+
249
+ return checkpoint_prefix
250
+
251
+ def adversarial_loss(y_true, y_pred):
252
+ mean = y_pred[:, :128]
253
+ ls = y_pred[:, 128:]
254
+ loss = -ls + 0.5 * (-1 + tf.math.exp(2.0 * ls) + tf.math.square(mean))
255
+ loss = K.mean(loss)
256
+ return loss
257
+
258
+ def normalize(input_image, real_image):
259
+ input_image = (input_image / 127.5) - 1
260
+ real_image = (real_image / 127.5) - 1
261
+
262
+ return input_image, real_image
263
+
264
+ def load_class_ids_filenames(class_id_path, filename_path):
265
+ with open(class_id_path, 'rb') as file:
266
+ class_id = pickle.load(file, encoding='latin1')
267
+
268
+ with open(filename_path, 'rb') as file:
269
+ filename = pickle.load(file, encoding='latin1')
270
+
271
+ return class_id, filename
272
+
273
+ def load_text_embeddings(text_embeddings):
274
+ with open(text_embeddings, 'rb') as file:
275
+ embeds = pickle.load(file, encoding='latin1')
276
+ embeds = np.array(embeds)
277
+
278
+ return embeds
279
+
280
+ def load_bbox(data_path):
281
+ bbox_path = data_path + '/bounding_boxes.txt'
282
+ image_path = data_path + '/images.txt'
283
+ bbox_df = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int)
284
+ filename_df = pd.read_csv(image_path, delim_whitespace=True, header=None)
285
+
286
+ filenames = filename_df[1].tolist()
287
+ bbox_dict = {i[:-4]:[] for i in filenames[:2]}
288
+
289
+ for i in range(0, len(filenames)):
290
+ bbox = bbox_df.iloc[i][1:].tolist()
291
+ dict_key = filenames[i][:-4]
292
+ bbox_dict[dict_key] = bbox
293
+
294
+ return bbox_dict
295
+
296
+ def load_images(image_path, bounding_box, size):
297
+ """Crops the image to the bounding box and then resizes it.
298
+ """
299
+ image = Image.open(image_path).convert('RGB')
300
+ w, h = image.size
301
+ if bounding_box is not None:
302
+ r = int(np.maximum(bounding_box[2], bounding_box[3]) * 0.75)
303
+ c_x = int((bounding_box[0] + bounding_box[2]) / 2)
304
+ c_y = int((bounding_box[1] + bounding_box[3]) / 2)
305
+ y1 = np.maximum(0, c_y - r)
306
+ y2 = np.minimum(h, c_y + r)
307
+ x1 = np.maximum(0, c_x - r)
308
+ x2 = np.minimum(w, c_x + r)
309
+ image = image.crop([x1, y1, x2, y2])
310
+
311
+ image = image.resize(size, PIL.Image.BILINEAR)
312
+ return image
313
+
314
+ def load_data(filename_path, class_id_path, dataset_path, embeddings_path, size):
315
+ """Loads the Dataset.
316
+ """
317
+ data_dir = "D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/birds"
318
+ train_dir = data_dir + "/train"
319
+ test_dir = data_dir + "/test"
320
+ embeddings_path_train = train_dir + "/char-CNN-RNN-embeddings.pickle"
321
+ embeddings_path_test = test_dir + "/char-CNN-RNN-embeddings.pickle"
322
+ filename_path_train = train_dir + "/filenames.pickle"
323
+ filename_path_test = test_dir + "/filenames.pickle"
324
+ class_id_path_train = train_dir + "/class_info.pickle"
325
+ class_id_path_test = test_dir + "/class_info.pickle"
326
+ dataset_path = "D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/CUB_200_2011"
327
+ class_id, filenames = load_class_ids_filenames(class_id_path, filename_path)
328
+ embeddings = load_text_embeddings(embeddings_path)
329
+ bbox_dict = load_bbox(dataset_path)
330
+
331
+ x, y, embeds = [], [], []
332
+
333
+ for i, filename in enumerate(filenames):
334
+ bbox = bbox_dict[filename]
335
+
336
+ try:
337
+ image_path = f'{dataset_path}/images/{filename}.jpg'
338
+ image = load_images(image_path, bbox, size)
339
+ e = embeddings[i, :, :]
340
+ embed_index = np.random.randint(0, e.shape[0] - 1)
341
+ embed = e[embed_index, :]
342
+
343
+ x.append(np.array(image))
344
+ y.append(class_id[i])
345
+ embeds.append(embed)
346
+
347
+ except Exception as e:
348
+ print(f'{e}')
349
+
350
+ x = np.array(x)
351
+ y = np.array(y)
352
+ embeds = np.array(embeds)
353
+
354
+ return x, y, embeds
355
+
356
+ def save_image(file, save_path):
357
+ """Saves the image at the specified file path.
358
+ """
359
+ image = plt.figure()
360
+ ax = image.add_subplot(1,1,1)
361
+ ax.imshow(file)
362
+ ax.axis("off")
363
+ plt.savefig(save_path)
364
+
365
+
366
+ # In[28]:
367
+
368
+
369
+ ############################################################
370
+ # StackGAN class
371
+ ############################################################
372
+
373
+ class StackGanStage1(object):
374
+ """StackGAN Stage 1 class."""
375
+
376
+ data_dir = "D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/birds"
377
+ train_dir = data_dir + "/train"
378
+ test_dir = data_dir + "/test"
379
+ embeddings_path_train = train_dir + "/char-CNN-RNN-embeddings.pickle"
380
+ embeddings_path_test = test_dir + "/char-CNN-RNN-embeddings.pickle"
381
+ filename_path_train = train_dir + "/filenames.pickle"
382
+ filename_path_test = test_dir + "/filenames.pickle"
383
+ class_id_path_train = train_dir + "/class_info.pickle"
384
+ class_id_path_test = test_dir + "/class_info.pickle"
385
+ dataset_path = "D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/CUB_200_2011"
386
+ def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage1_generator_lr=0.0002, stage1_discriminator_lr=0.0002):
387
+ self.epochs = epochs
388
+ self.z_dim = z_dim
389
+ self.enable_function = enable_function
390
+ self.stage1_generator_lr = stage1_generator_lr
391
+ self.stage1_discriminator_lr = stage1_discriminator_lr
392
+ self.image_size = 64
393
+ self.conditioning_dim = 128
394
+ self.batch_size = batch_size
395
+
396
+ self.stage1_generator_optimizer = Adam(lr=stage1_generator_lr, beta_1=0.5, beta_2=0.999)
397
+ self.stage1_discriminator_optimizer = Adam(lr=stage1_discriminator_lr, beta_1=0.5, beta_2=0.999)
398
+
399
+ self.stage1_generator = build_stage1_generator()
400
+ self.stage1_generator.compile(loss='mse', optimizer=self.stage1_generator_optimizer)
401
+
402
+ self.stage1_discriminator = build_stage1_discriminator()
403
+ self.stage1_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage1_discriminator_optimizer)
404
+
405
+ self.ca_network = build_ca_network()
406
+ self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')
407
+
408
+ self.embedding_compressor = build_embedding_compressor()
409
+ self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')
410
+
411
+ self.stage1_adversarial = build_adversarial(self.stage1_generator, self.stage1_discriminator)
412
+ self.stage1_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage1_generator_optimizer)
413
+
414
+ self.checkpoint1 = tf.train.Checkpoint(
415
+ generator_optimizer=self.stage1_generator_optimizer,
416
+ discriminator_optimizer=self.stage1_discriminator_optimizer,
417
+ generator=self.stage1_generator,
418
+ discriminator=self.stage1_discriminator)
419
+
420
+ def visualize_stage1(self):
421
+ """Running Tensorboard visualizations.
422
+ """
423
+ tb = TensorBoard(log_dir="logs/".format(time.time()))
424
+ tb.set_model(self.stage1_generator)
425
+ tb.set_model(self.stage1_discriminator)
426
+ tb.set_model(self.ca_network)
427
+ tb.set_model(self.embedding_compressor)
428
+
429
+ def train_stage1(self):
430
+ """Trains the stage1 StackGAN.
431
+ """
432
+ x_train, y_train, train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,
433
+ dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(64, 64))
434
+
435
+ x_test, y_test, test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test,
436
+ dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(64, 64))
437
+
438
+ real = np.ones((self.batch_size, 1), dtype='float') * 0.9
439
+ fake = np.zeros((self.batch_size, 1), dtype='float') * 0.1
440
+
441
+ for epoch in range(self.epochs):
442
+ print(f'Epoch: {epoch}')
443
+
444
+ gen_loss = []
445
+ dis_loss = []
446
+
447
+ num_batches = int(x_train.shape[0] / self.batch_size)
448
+
449
+ for i in range(num_batches):
450
+
451
+ latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
452
+ embedding_text = train_embeds[i * self.batch_size:(i + 1) * self.batch_size]
453
+ compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)
454
+ compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, 128))
455
+ compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))
456
+
457
+ image_batch = x_train[i * self.batch_size:(i+1) * self.batch_size]
458
+ image_batch = (image_batch - 127.5) / 127.5
459
+
460
+ gen_images, _ = self.stage1_generator.predict([embedding_text, latent_space])
461
+
462
+ discriminator_loss = self.stage1_discriminator.train_on_batch([image_batch, compressed_embedding],
463
+ np.reshape(real, (self.batch_size, 1)))
464
+
465
+ discriminator_loss_gen = self.stage1_discriminator.train_on_batch([gen_images, compressed_embedding],
466
+ np.reshape(fake, (self.batch_size, 1)))
467
+
468
+ discriminator_loss_wrong = self.stage1_discriminator.train_on_batch([gen_images[: self.batch_size-1], compressed_embedding[1:]],
469
+ np.reshape(fake[1:], (self.batch_size-1, 1)))
470
+
471
+ # Discriminator loss
472
+ d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_wrong))
473
+ dis_loss.append(d_loss)
474
+
475
+ print(f'Discriminator Loss: {d_loss}')
476
+
477
+ # Generator loss
478
+ g_loss = self.stage1_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],
479
+ [K.ones((self.batch_size, 1)) * 0.9, K.ones((self.batch_size, 256)) * 0.9])
480
+
481
+ print(f'Generator Loss: {g_loss}')
482
+ gen_loss.append(g_loss)
483
+
484
+ if epoch % 5 == 0:
485
+ latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
486
+ embedding_batch = test_embeds[0 : self.batch_size]
487
+ gen_images, _ = self.stage1_generator.predict_on_batch([embedding_batch, latent_space])
488
+
489
+ for i, image in enumerate(gen_images[:10]):
490
+ save_image(image, f'test/gen_1_{epoch}_{i}')
491
+
492
+ if epoch % 25 == 0:
493
+ self.stage1_generator.save_weights('weights/stage1_gen.h5')
494
+ self.stage1_discriminator.save_weights("weights/stage1_disc.h5")
495
+ self.ca_network.save_weights('weights/stage1_ca.h5')
496
+ self.embedding_compressor.save_weights('weights/stage1_embco.h5')
497
+ self.stage1_adversarial.save_weights('weights/stage1_adv.h5')
498
+
499
+ self.stage1_generator.save_weights('weights/stage1_gen.h5')
500
+ self.stage1_discriminator.save_weights("weights/stage1_disc.h5")
501
+
502
+
503
+ # In[ ]:
504
+
505
+
506
+ stage1 = StackGanStage1()
507
+ stage1.train_stage1()
508
+
509
+
510
+ # ### Check test folder for gernerated images from Stage1 Generator
511
+ # ### Let's Implement Stage 2 Generator
512
+
513
+ # In[29]:
514
+
515
+
516
+ ############################################################
517
+ # Stage 2 Generator Network
518
+ ############################################################
519
+
520
+ def concat_along_dims(inputs):
521
+ """Joins the conditioned text with the encoded image along the dimensions.
522
+
523
+ Args:
524
+ inputs: consisting of conditioned text and encoded images as [c,x].
525
+
526
+ Returns:
527
+ Joint block along the dimensions.
528
+ """
529
+ c = inputs[0]
530
+ x = inputs[1]
531
+
532
+ c = K.expand_dims(c, axis=1)
533
+ c = K.expand_dims(c, axis=1)
534
+ c = K.tile(c, [1, 16, 16, 1])
535
+ return K.concatenate([c, x], axis = 3)
536
+
537
+ def residual_block(input):
538
+ """Residual block with plain identity connections.
539
+
540
+ Args:
541
+ inputs: input layer or an encoded layer
542
+
543
+ Returns:
544
+ Layer with computed identity mapping.
545
+ """
546
+ x = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,
547
+ kernel_initializer='he_uniform')(input)
548
+ x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
549
+ x = ReLU()(x)
550
+
551
+ x = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,
552
+ kernel_initializer='he_uniform')(x)
553
+ x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
554
+
555
+ x = add([x, input])
556
+ x = ReLU()(x)
557
+
558
+ return x
559
+
560
+ def build_stage2_generator():
561
+ """Build the Stage 2 Generator Network using the conditioning text and images from stage 1.
562
+
563
+ Returns:
564
+ Stage 2 Generator Model for StackGAN.
565
+ """
566
+ input_layer1 = Input(shape=(1024,))
567
+ input_images = Input(shape=(64, 64, 3))
568
+
569
+ # Conditioning Augmentation
570
+ ca = Dense(256)(input_layer1)
571
+ mls = LeakyReLU(alpha=0.2)(ca)
572
+ c = Lambda(conditioning_augmentation)(mls)
573
+
574
+ # Downsampling block
575
+ x = ZeroPadding2D(padding=(1,1))(input_images)
576
+ x = Conv2D(128, kernel_size=(3,3), strides=1, use_bias=False,
577
+ kernel_initializer='he_uniform')(x)
578
+ x = ReLU()(x)
579
+
580
+ x = ZeroPadding2D(padding=(1,1))(x)
581
+ x = Conv2D(256, kernel_size=(4,4), strides=2, use_bias=False,
582
+ kernel_initializer='he_uniform')(x)
583
+ x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
584
+ x = ReLU()(x)
585
+
586
+ x = ZeroPadding2D(padding=(1,1))(x)
587
+ x = Conv2D(512, kernel_size=(4,4), strides=2, use_bias=False,
588
+ kernel_initializer='he_uniform')(x)
589
+ x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
590
+ x = ReLU()(x)
591
+
592
+ # Concatenate text conditioning block with the encoded image
593
+ concat = concat_along_dims([c, x])
594
+
595
+ # Residual Blocks
596
+ x = ZeroPadding2D(padding=(1,1))(concat)
597
+ x = Conv2D(512, kernel_size=(3,3), use_bias=False, kernel_initializer='he_uniform')(x)
598
+ x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
599
+ x = ReLU()(x)
600
+
601
+ x = residual_block(x)
602
+ x = residual_block(x)
603
+ x = residual_block(x)
604
+ x = residual_block(x)
605
+
606
+ # Upsampling Blocks
607
+ x = UpSamplingBlock(x, 512)
608
+ x = UpSamplingBlock(x, 256)
609
+ x = UpSamplingBlock(x, 128)
610
+ x = UpSamplingBlock(x, 64)
611
+
612
+ x = Conv2D(3, kernel_size=(3,3), padding='same', use_bias=False, kernel_initializer='he_uniform')(x)
613
+ x = Activation('tanh')(x)
614
+
615
+ stage2_gen = Model(inputs=[input_layer1, input_images], outputs=[x, mls])
616
+ return stage2_gen
617
+
618
+
619
+ # In[30]:
620
+
621
+
622
+ generator_stage2 = build_stage2_generator()
623
+ generator_stage2.summary()
624
+
625
+
626
+ # In[31]:
627
+
628
+
629
+ ############################################################
630
+ # Stage 2 Discriminator Network
631
+ ############################################################
632
+
633
+ def build_stage2_discriminator():
634
+ """Builds the Stage 2 Discriminator that uses the 256x256 resolution images from the generator
635
+ and the compressed and spatially replicated embeddings.
636
+
637
+ Returns:
638
+ Stage 2 Discriminator Model for StackGAN.
639
+ """
640
+ input_layer1 = Input(shape=(256, 256, 3))
641
+
642
+ x = Conv2D(64, kernel_size=(4,4), padding='same', strides=2, use_bias=False,
643
+ kernel_initializer='he_uniform')(input_layer1)
644
+ x = LeakyReLU(alpha=0.2)(x)
645
+
646
+ x = ConvBlock(x, 128)
647
+ x = ConvBlock(x, 256)
648
+ x = ConvBlock(x, 512)
649
+ x = ConvBlock(x, 1024)
650
+ x = ConvBlock(x, 2048)
651
+ x = ConvBlock(x, 1024, (1,1), 1)
652
+ x = ConvBlock(x, 512, (1,1), 1, False)
653
+
654
+ x1 = ConvBlock(x, 128, (1,1), 1)
655
+ x1 = ConvBlock(x1, 128, (3,3), 1)
656
+ x1 = ConvBlock(x1, 512, (3,3), 1, False)
657
+
658
+ x2 = add([x, x1])
659
+ x2 = LeakyReLU(alpha=0.2)(x2)
660
+
661
+ # Concatenate compressed and spatially replicated embedding
662
+ input_layer2 = Input(shape=(4, 4, 128))
663
+ concat = concatenate([x2, input_layer2])
664
+
665
+ x3 = Conv2D(512, kernel_size=(1,1), strides=1, padding='same', kernel_initializer='he_uniform')(concat)
666
+ x3 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x3)
667
+ x3 = LeakyReLU(alpha=0.2)(x3)
668
+
669
+ # Flatten and add a FC layer
670
+ x3 = Flatten()(x3)
671
+ x3 = Dense(1)(x3)
672
+ x3 = Activation('sigmoid')(x3)
673
+
674
+ stage2_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x3])
675
+ return stage2_dis
676
+
677
+
678
+ # In[32]:
679
+
680
+
681
+ discriminator_stage2 = build_stage2_discriminator()
682
+ discriminator_stage2.summary()
683
+
684
+
685
+ # In[33]:
686
+
687
+
688
+ ############################################################
689
+ # Stage 2 Adversarial Model
690
+ ############################################################
691
+
692
+ def stage2_adversarial_network(stage2_disc, stage2_gen, stage1_gen):
693
+ """Stage 2 Adversarial Network.
694
+
695
+ Args:
696
+ stage2_disc: Stage 2 Discriminator Model.
697
+ stage2_gen: Stage 2 Generator Model.
698
+ stage1_gen: Stage 1 Generator Model.
699
+
700
+ Returns:
701
+ Stage 2 Adversarial network.
702
+ """
703
+ conditioned_embedding = Input(shape=(1024, ))
704
+ latent_space = Input(shape=(100, ))
705
+ compressed_replicated = Input(shape=(4, 4, 128))
706
+
707
+ #the discriminator is trained separately and stage1_gen already trained, and this is the reason why we freeze its layers by setting the property trainable=false
708
+ input_images, ca = stage1_gen([conditioned_embedding, latent_space])
709
+ stage2_disc.trainable = False
710
+ stage1_gen.trainable = False
711
+
712
+ images, ca2 = stage2_gen([conditioned_embedding, input_images])
713
+ probability = stage2_disc([images, compressed_replicated])
714
+
715
+ return Model(inputs=[conditioned_embedding, latent_space, compressed_replicated],
716
+ outputs=[probability, ca2])
717
+
718
+
719
+ # In[34]:
720
+
721
+
722
+ adversarial_stage2 = stage2_adversarial_network(discriminator_stage2, generator_stage2, generator)
723
+ adversarial_stage2.summary()
724
+
725
+
726
+ # In[35]:
727
+
728
+
729
+ class StackGanStage2(object):
730
+ """StackGAN Stage 2 class.
731
+
732
+ Args:
733
+ epochs: Number of epochs
734
+ z_dim: Latent space dimensions
735
+ batch_size: Batch Size
736
+ enable_function: If True, training function is decorated with tf.function
737
+ stage2_generator_lr: Learning rate for stage 2 generator
738
+ stage2_discriminator_lr: Learning rate for stage 2 discriminator
739
+ """
740
+ def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage2_generator_lr=0.0002, stage2_discriminator_lr=0.0002):
741
+ self.epochs = epochs
742
+ self.z_dim = z_dim
743
+ self.enable_function = enable_function
744
+ self.stage1_generator_lr = stage2_generator_lr
745
+ self.stage1_discriminator_lr = stage2_discriminator_lr
746
+ self.low_image_size = 64
747
+ self.high_image_size = 256
748
+ self.conditioning_dim = 128
749
+ self.batch_size = batch_size
750
+ self.stage2_generator_optimizer = Adam(lr=stage2_generator_lr, beta_1=0.5, beta_2=0.999)
751
+ self.stage2_discriminator_optimizer = Adam(lr=stage2_discriminator_lr, beta_1=0.5, beta_2=0.999)
752
+ self.stage1_generator = build_stage1_generator()
753
+ self.stage1_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)
754
+ self.stage1_generator.load_weights('weights/stage1_gen.h5')
755
+ self.stage2_generator = build_stage2_generator()
756
+ self.stage2_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)
757
+
758
+ self.stage2_discriminator = build_stage2_discriminator()
759
+ self.stage2_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage2_discriminator_optimizer)
760
+
761
+ self.ca_network = build_ca_network()
762
+ self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')
763
+
764
+ self.embedding_compressor = build_embedding_compressor()
765
+ self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')
766
+
767
+ self.stage2_adversarial = stage2_adversarial_network(self.stage2_discriminator, self.stage2_generator, self.stage1_generator)
768
+ self.stage2_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage2_generator_optimizer)
769
+
770
+ self.checkpoint2 = tf.train.Checkpoint(
771
+ generator_optimizer=self.stage2_generator_optimizer,
772
+ discriminator_optimizer=self.stage2_discriminator_optimizer,
773
+ generator=self.stage2_generator,
774
+ discriminator=self.stage2_discriminator,
775
+ generator1=self.stage1_generator)
776
+
777
+ def visualize_stage2(self):
778
+ """Running Tensorboard visualizations.
779
+ """
780
+ tb = TensorBoard(log_dir="logs/".format(time.time()))
781
+ tb.set_model(self.stage2_generator)
782
+ tb.set_model(self.stage2_discriminator)
783
+
784
+ def train_stage2(self):
785
+ """Trains Stage 2 StackGAN.
786
+ """
787
+ x_high_train, y_high_train, high_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,
788
+ dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(256, 256))
789
+
790
+ x_high_test, y_high_test, high_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test,
791
+ dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(256, 256))
792
+
793
+ x_low_train, y_low_train, low_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,
794
+ dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(64, 64))
795
+
796
+ x_low_test, y_low_test, low_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test,
797
+ dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(64, 64))
798
+
799
+ real = np.ones((self.batch_size, 1), dtype='float') * 0.9
800
+ fake = np.zeros((self.batch_size, 1), dtype='float') * 0.1
801
+
802
+ for epoch in range(self.epochs):
803
+ print(f'Epoch: {epoch}')
804
+
805
+ gen_loss = []
806
+ disc_loss = []
807
+
808
+ num_batches = int(x_high_train.shape[0] / self.batch_size)
809
+
810
+ for i in range(num_batches):
811
+
812
+ latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
813
+ embedding_text = high_train_embeds[i * self.batch_size:(i + 1) * self.batch_size]
814
+ compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)
815
+ compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, self.conditioning_dim))
816
+ compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))
817
+
818
+ image_batch = x_high_train[i * self.batch_size:(i+1) * self.batch_size]
819
+ image_batch = (image_batch - 127.5) / 127.5
820
+
821
+ low_res_fakes, _ = self.stage1_generator.predict([embedding_text, latent_space], verbose=3)
822
+ high_res_fakes, _ = self.stage2_generator.predict([embedding_text, low_res_fakes], verbose=3)
823
+
824
+ discriminator_loss = self.stage2_discriminator.train_on_batch([image_batch, compressed_embedding],
825
+ np.reshape(real, (self.batch_size, 1)))
826
+
827
+ discriminator_loss_gen = self.stage2_discriminator.train_on_batch([high_res_fakes, compressed_embedding],
828
+ np.reshape(fake, (self.batch_size, 1)))
829
+
830
+ discriminator_loss_fake = self.stage2_discriminator.train_on_batch([image_batch[:(self.batch_size-1)], compressed_embedding[1:]],
831
+ np.reshape(fake[1:], (self.batch_size - 1, 1)))
832
+
833
+ d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_fake))
834
+ disc_loss.append(d_loss)
835
+
836
+ print(f'Discriminator Loss: {d_loss}')
837
+
838
+ g_loss = self.stage2_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],
839
+ [K.ones((self.batch_size, 1)) * 0.9, K.ones((self.batch_size, 256)) * 0.9])
840
+ gen_loss.append(g_loss)
841
+
842
+ print(f'Generator Loss: {g_loss}')
843
+
844
+ if epoch % 5 == 0:
845
+ latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
846
+ embedding_batch = high_test_embeds[0 : self.batch_size]
847
+
848
+ low_fake_images, _ = self.stage1_generator.predict([embedding_batch, latent_space], verbose=3)
849
+ high_fake_images, _ = self.stage2_generator.predict([embedding_batch, low_fake_images], verbose=3)
850
+
851
+ for i, image in enumerate(high_fake_images[:10]):
852
+ save_image(image, f'results_stage2/gen_{epoch}_{i}.png')
853
+
854
+ if epoch % 10 == 0:
855
+ self.stage2_generator.save_weights('weights/stage2_gen.h5')
856
+ self.stage2_discriminator.save_weights("weights/stage2_disc.h5")
857
+ self.ca_network.save_weights('weights/stage2_ca.h5')
858
+ self.embedding_compressor.save_weights('weights/stage2_embco.h5')
859
+ self.stage2_adversarial.save_weights('weights/stage2_adv.h5')
860
+
861
+ self.stage2_generator.save_weights('weights/stage2_gen.h5')
862
+ self.stage2_discriminator.save_weights("weights/stage2_disc.h5")
863
+
864
+
865
+ # In[ ]:
866
+
867
+
868
+ stage2 = StackGanStage2()
869
+ stage2.train_stage2()
870
+