Truptidand committed on
Commit
c180ca7
·
1 Parent(s): f13cdf6

Delete GAN.ipynb

Browse files
Files changed (1) hide show
  1. GAN.ipynb +0 -1514
GAN.ipynb DELETED
@@ -1,1514 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "a3677b66",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "import numpy as np\n",
11
- "import pandas as pd\n",
12
- "import matplotlib.pyplot as plt\n",
13
- "import seaborn as sns\n",
14
- "import os\n",
15
- "import pickle\n",
16
- "import time\n",
17
- "import random"
18
- ]
19
- },
20
- {
21
- "cell_type": "code",
22
- "execution_count": 8,
23
- "id": "76ece7f8",
24
- "metadata": {},
25
- "outputs": [],
26
- "source": [
27
- "import PIL\n",
28
- "from PIL import Image\n",
29
- "import keras.backend as K\n",
30
- "import tensorflow as tf\n",
31
- "from tensorflow import keras\n",
32
- "from keras.optimizers import Adam\n",
33
- "from keras.models import Sequential\n",
34
- "from keras import layers,Model,Input\n",
35
- "from keras.layers import Lambda,Reshape,UpSampling2D,ReLU,add,ZeroPadding2D\n",
36
- "from keras.layers import Activation,BatchNormalization,Concatenate,concatenate\n",
37
- "from keras.layers import Dense,Conv2D,Flatten,Dropout,LeakyReLU\n",
38
- "from keras.preprocessing.image import ImageDataGenerator"
39
- ]
40
- },
41
- {
42
- "cell_type": "markdown",
43
- "id": "b8980cd5",
44
- "metadata": {},
45
- "source": [
46
- "### Conditioning Augmentation Network"
47
- ]
48
- },
49
- {
50
- "cell_type": "code",
51
- "execution_count": 3,
52
- "id": "d3027cda",
53
- "metadata": {},
54
- "outputs": [],
55
- "source": [
56
- "# conditioned by the text.\n",
def conditioning_augmentation(x):
    """Sample the text conditioning variable via the reparameterisation trick.

    Args:
        x: 2-D tensor; columns ``[:128]`` are read as the mean and the
            remaining columns as log-sigma.

    Returns:
        c: ``mean + exp(log_sigma) * epsilon`` where ``epsilon`` is standard
            normal noise with the same shape as ``mean``.
    """
    mean = x[:, :128]
    log_sigma = x[:, 128:]

    stddev = tf.math.exp(log_sigma)
    # Draw noise with the full (batch, dim) shape of `mean`. The previous
    # shape of (dim,) produced a single noise vector that was broadcast to
    # every sample of the batch, so all samples shared the same epsilon.
    epsilon = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
    c = mean + stddev * epsilon
    return c
73
- "\n",
def build_ca_network():
    """Create the stand-alone conditioning augmentation model.

    Returns:
        A Keras ``Model`` mapping a 1024-wide input vector to the sampled
        conditioning variable produced by ``conditioning_augmentation``.
    """
    # 1024-wide input vector (presumably the text embedding; the generator
    # uses an input of the same width -- confirm against the data loader).
    embedding_in = Input(shape=(1024,))

    mean_logsigma = Dense(256)(embedding_in)
    mean_logsigma = LeakyReLU(alpha=0.2)(mean_logsigma)
    conditioned = Lambda(conditioning_augmentation)(mean_logsigma)

    ca_model = Model(inputs=[embedding_in], outputs=[conditioned])
    return ca_model
82
- ]
83
- },
84
- {
85
- "cell_type": "markdown",
86
- "id": "87340e8b",
87
- "metadata": {},
88
- "source": [
89
- "### Stage 1 Generator Network"
90
- ]
91
- },
92
- {
93
- "cell_type": "code",
94
- "execution_count": 4,
95
- "id": "c430524d",
96
- "metadata": {},
97
- "outputs": [],
98
- "source": [
def UpSamplingBlock(x, num_kernels):
    """Double the spatial size, then apply Conv2D -> BatchNormalization -> ReLU.

    Args:
        x: The preceding layer's output tensor.
        num_kernels: Number of filters of the 3x3 convolution.

    Returns:
        The ReLU-activated tensor produced by the block.
    """
    upsampled = UpSampling2D(size=(2, 2))(x)

    conv_layer = Conv2D(
        num_kernels,
        kernel_size=(3, 3),
        padding='same',
        strides=1,
        use_bias=False,
        kernel_initializer='he_uniform',
    )
    convolved = conv_layer(upsampled)

    norm_layer = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')
    normalised = norm_layer(convolved)

    return ReLU()(normalised)
115
- "\n",
116
- "\n",
def build_stage1_generator():
    """Build the Stage 1 generator.

    Returns:
        A Keras ``Model`` with inputs ``[1024-wide vector, 100-wide noise]``
        and outputs ``[64x64x3 tanh image, 256-wide mean/log-sigma tensor]``.
    """
    embedding_in = Input(shape=(1024,))
    mean_logsigma = Dense(256)(embedding_in)
    mean_logsigma = LeakyReLU(alpha=0.2)(mean_logsigma)

    # Sample the conditioning variable from the mean/log-sigma tensor.
    conditioned = Lambda(conditioning_augmentation)(mean_logsigma)

    noise_in = Input(shape=(100,))
    joint = Concatenate(axis=1)([conditioned, noise_in])

    # Project to 4*4*1024 values and fold them into a 4x4 feature map.
    features = Dense(16384, use_bias=False)(joint)
    features = ReLU()(features)
    features = Reshape((4, 4, 1024), input_shape=(16384,))(features)

    # Four 2x upsampling stages: 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64,
    # ending with 64 channels.
    for filters in (512, 256, 128, 64):
        features = UpSamplingBlock(features, filters)

    # Final 3-channel projection squashed into [-1, 1] by tanh.
    image = Conv2D(3, kernel_size=3, padding='same', strides=1, use_bias=False,
                   kernel_initializer='he_uniform')(features)
    image = Activation('tanh')(image)

    return Model(inputs=[embedding_in, noise_in], outputs=[image, mean_logsigma])
144
- ]
145
- },
146
- {
147
- "cell_type": "code",
148
- "execution_count": 5,
149
- "id": "0febcb4f",
150
- "metadata": {},
151
- "outputs": [
152
- {
153
- "name": "stdout",
154
- "output_type": "stream",
155
- "text": [
156
- "Model: \"model\"\n",
157
- "__________________________________________________________________________________________________\n",
158
- " Layer (type) Output Shape Param # Connected to \n",
159
- "==================================================================================================\n",
160
- " input_1 (InputLayer) [(None, 1024)] 0 [] \n",
161
- " \n",
162
- " dense (Dense) (None, 256) 262400 ['input_1[0][0]'] \n",
163
- " \n",
164
- " leaky_re_lu (LeakyReLU) (None, 256) 0 ['dense[0][0]'] \n",
165
- " \n",
166
- " lambda (Lambda) (None, 128) 0 ['leaky_re_lu[0][0]'] \n",
167
- " \n",
168
- " input_2 (InputLayer) [(None, 100)] 0 [] \n",
169
- " \n",
170
- " concatenate (Concatenate) (None, 228) 0 ['lambda[0][0]', \n",
171
- " 'input_2[0][0]'] \n",
172
- " \n",
173
- " dense_1 (Dense) (None, 16384) 3735552 ['concatenate[0][0]'] \n",
174
- " \n",
175
- " re_lu (ReLU) (None, 16384) 0 ['dense_1[0][0]'] \n",
176
- " \n",
177
- " reshape (Reshape) (None, 4, 4, 1024) 0 ['re_lu[0][0]'] \n",
178
- " \n",
179
- " up_sampling2d (UpSampling2D) (None, 8, 8, 1024) 0 ['reshape[0][0]'] \n",
180
- " \n",
181
- " conv2d (Conv2D) (None, 8, 8, 512) 4718592 ['up_sampling2d[0][0]'] \n",
182
- " \n",
183
- " batch_normalization (BatchNorm (None, 8, 8, 512) 2048 ['conv2d[0][0]'] \n",
184
- " alization) \n",
185
- " \n",
186
- " re_lu_1 (ReLU) (None, 8, 8, 512) 0 ['batch_normalization[0][0]'] \n",
187
- " \n",
188
- " up_sampling2d_1 (UpSampling2D) (None, 16, 16, 512) 0 ['re_lu_1[0][0]'] \n",
189
- " \n",
190
- " conv2d_1 (Conv2D) (None, 16, 16, 256) 1179648 ['up_sampling2d_1[0][0]'] \n",
191
- " \n",
192
- " batch_normalization_1 (BatchNo (None, 16, 16, 256) 1024 ['conv2d_1[0][0]'] \n",
193
- " rmalization) \n",
194
- " \n",
195
- " re_lu_2 (ReLU) (None, 16, 16, 256) 0 ['batch_normalization_1[0][0]'] \n",
196
- " \n",
197
- " up_sampling2d_2 (UpSampling2D) (None, 32, 32, 256) 0 ['re_lu_2[0][0]'] \n",
198
- " \n",
199
- " conv2d_2 (Conv2D) (None, 32, 32, 128) 294912 ['up_sampling2d_2[0][0]'] \n",
200
- " \n",
201
- " batch_normalization_2 (BatchNo (None, 32, 32, 128) 512 ['conv2d_2[0][0]'] \n",
202
- " rmalization) \n",
203
- " \n",
204
- " re_lu_3 (ReLU) (None, 32, 32, 128) 0 ['batch_normalization_2[0][0]'] \n",
205
- " \n",
206
- " up_sampling2d_3 (UpSampling2D) (None, 64, 64, 128) 0 ['re_lu_3[0][0]'] \n",
207
- " \n",
208
- " conv2d_3 (Conv2D) (None, 64, 64, 64) 73728 ['up_sampling2d_3[0][0]'] \n",
209
- " \n",
210
- " batch_normalization_3 (BatchNo (None, 64, 64, 64) 256 ['conv2d_3[0][0]'] \n",
211
- " rmalization) \n",
212
- " \n",
213
- " re_lu_4 (ReLU) (None, 64, 64, 64) 0 ['batch_normalization_3[0][0]'] \n",
214
- " \n",
215
- " conv2d_4 (Conv2D) (None, 64, 64, 3) 1728 ['re_lu_4[0][0]'] \n",
216
- " \n",
217
- " activation (Activation) (None, 64, 64, 3) 0 ['conv2d_4[0][0]'] \n",
218
- " \n",
219
- "==================================================================================================\n",
220
- "Total params: 10,270,400\n",
221
- "Trainable params: 10,268,480\n",
222
- "Non-trainable params: 1,920\n",
223
- "__________________________________________________________________________________________________\n"
224
- ]
225
- }
226
- ],
227
- "source": [
228
- "generator = build_stage1_generator()\n",
229
- "generator.summary()"
230
- ]
231
- },
232
- {
233
- "cell_type": "markdown",
234
- "id": "a14d9d1c",
235
- "metadata": {},
236
- "source": [
237
- "### Stage 1 Discriminator Network"
238
- ]
239
- },
240
- {
241
- "cell_type": "code",
242
- "execution_count": 9,
243
- "id": "32b436ac",
244
- "metadata": {},
245
- "outputs": [],
246
- "source": [
def ConvBlock(x, num_kernels, kernel_size=(4,4), strides=2, activation=True):
    """Apply Conv2D -> BatchNormalization and, optionally, LeakyReLU(0.2).

    Args:
        x: The preceding layer's output tensor.
        num_kernels: Number of filters of the convolution.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.
        activation: When truthy, a LeakyReLU is appended after normalisation.

    Returns:
        The output tensor of the last layer applied.
    """
    conv_layer = Conv2D(
        num_kernels,
        kernel_size=kernel_size,
        padding='same',
        strides=strides,
        use_bias=False,
        kernel_initializer='he_uniform',
    )
    features = conv_layer(x)
    features = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(features)

    if not activation:
        return features
    return LeakyReLU(alpha=0.2)(features)
264
- "\n",
265
- "\n",
def build_embedding_compressor():
    """Build the model that compresses a 1024-wide vector to 128 ReLU units.

    Returns:
        A Keras ``Model`` consisting of ``Dense(128)`` followed by ``ReLU``.
    """
    embedding_in = Input(shape=(1024,))
    compressed = ReLU()(Dense(128)(embedding_in))
    return Model(inputs=[embedding_in], outputs=[compressed])
275
- "\n",
276
- "# the discriminator is fed with two inputs, the feature from Generator and the text embedding\n",
def build_stage1_discriminator():
    """Builds the Stage 1 Discriminator that uses the 64x64 resolution images from the generator
    and the compressed and spatially replicated embedding.

    Returns:
        Stage 1 Discriminator Model for StackGAN, with inputs
        ``[64x64x3 image, 4x4x128 embedding map]`` and a single sigmoid output.
    """
    input_layer1 = Input(shape=(64, 64, 3))

    # Downsample the image 64 -> 32 -> 16 -> 8 -> 4 while widening to 512 channels.
    x = Conv2D(64, kernel_size=(4,4), strides=2, padding='same', use_bias=False,
               kernel_initializer='he_uniform')(input_layer1)
    x = LeakyReLU(alpha=0.2)(x)

    x = ConvBlock(x, 128)
    x = ConvBlock(x, 256)
    x = ConvBlock(x, 512)

    # Second input: the compressed and spatially replicated text embedding.
    input_layer2 = Input(shape=(4, 4, 128))
    concat = concatenate([x, input_layer2])

    # Fuse image features with the embedding. Each layer must consume the
    # previous `x1`; feeding `x` here (as the code did before) silently
    # discarded the 1x1 convolution and disconnected `input_layer2`, so the
    # discriminator never saw the text conditioning.
    x1 = Conv2D(512, kernel_size=(1,1), padding='same', strides=1, use_bias=False,
                kernel_initializer='he_uniform')(concat)
    x1 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x1)
    x1 = LeakyReLU(alpha=0.2)(x1)

    # Flatten and add a FC layer to predict.
    x1 = Flatten()(x1)
    x1 = Dense(1)(x1)
    x1 = Activation('sigmoid')(x1)

    stage1_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x1])
    return stage1_dis
310
- ]
311
- },
312
- {
313
- "cell_type": "code",
314
- "execution_count": 10,
315
- "id": "98090438",
316
- "metadata": {},
317
- "outputs": [
318
- {
319
- "name": "stdout",
320
- "output_type": "stream",
321
- "text": [
322
- "Model: \"model_1\"\n",
323
- "__________________________________________________________________________________________________\n",
324
- " Layer (type) Output Shape Param # Connected to \n",
325
- "==================================================================================================\n",
326
- " input_5 (InputLayer) [(None, 64, 64, 3)] 0 [] \n",
327
- " \n",
328
- " conv2d_9 (Conv2D) (None, 32, 32, 64) 3072 ['input_5[0][0]'] \n",
329
- " \n",
330
- " leaky_re_lu_5 (LeakyReLU) (None, 32, 32, 64) 0 ['conv2d_9[0][0]'] \n",
331
- " \n",
332
- " conv2d_10 (Conv2D) (None, 16, 16, 128) 131072 ['leaky_re_lu_5[0][0]'] \n",
333
- " \n",
334
- " batch_normalization_7 (BatchNo (None, 16, 16, 128) 512 ['conv2d_10[0][0]'] \n",
335
- " rmalization) \n",
336
- " \n",
337
- " leaky_re_lu_6 (LeakyReLU) (None, 16, 16, 128) 0 ['batch_normalization_7[0][0]'] \n",
338
- " \n",
339
- " conv2d_11 (Conv2D) (None, 8, 8, 256) 524288 ['leaky_re_lu_6[0][0]'] \n",
340
- " \n",
341
- " batch_normalization_8 (BatchNo (None, 8, 8, 256) 1024 ['conv2d_11[0][0]'] \n",
342
- " rmalization) \n",
343
- " \n",
344
- " leaky_re_lu_7 (LeakyReLU) (None, 8, 8, 256) 0 ['batch_normalization_8[0][0]'] \n",
345
- " \n",
346
- " conv2d_12 (Conv2D) (None, 4, 4, 512) 2097152 ['leaky_re_lu_7[0][0]'] \n",
347
- " \n",
348
- " batch_normalization_9 (BatchNo (None, 4, 4, 512) 2048 ['conv2d_12[0][0]'] \n",
349
- " rmalization) \n",
350
- " \n",
351
- " leaky_re_lu_8 (LeakyReLU) (None, 4, 4, 512) 0 ['batch_normalization_9[0][0]'] \n",
352
- " \n",
353
- " leaky_re_lu_9 (LeakyReLU) (None, 4, 4, 512) 0 ['leaky_re_lu_8[0][0]'] \n",
354
- " \n",
355
- " flatten (Flatten) (None, 8192) 0 ['leaky_re_lu_9[0][0]'] \n",
356
- " \n",
357
- " dense_2 (Dense) (None, 1) 8193 ['flatten[0][0]'] \n",
358
- " \n",
359
- " input_6 (InputLayer) [(None, 4, 4, 128)] 0 [] \n",
360
- " \n",
361
- " activation_1 (Activation) (None, 1) 0 ['dense_2[0][0]'] \n",
362
- " \n",
363
- "==================================================================================================\n",
364
- "Total params: 2,767,361\n",
365
- "Trainable params: 2,765,569\n",
366
- "Non-trainable params: 1,792\n",
367
- "__________________________________________________________________________________________________\n"
368
- ]
369
- }
370
- ],
371
- "source": [
372
- "discriminator = build_stage1_discriminator()\n",
373
- "discriminator.summary()"
374
- ]
375
- },
376
- {
377
- "cell_type": "markdown",
378
- "id": "cdc2a75a",
379
- "metadata": {},
380
- "source": [
381
- "### Stage 1 Adversarial Model (Building a GAN)"
382
- ]
383
- },
384
- {
385
- "cell_type": "code",
386
- "execution_count": 11,
387
- "id": "5d0678f7",
388
- "metadata": {},
389
- "outputs": [],
390
- "source": [
391
- "# Building GAN with Generator and Discriminator\n",
392
- "\n",
def build_adversarial(generator_model, discriminator_model):
    """Chain the generator and the discriminator into the Stage 1 adversarial model.

    Args:
        generator_model: Stage 1 Generator Model.
        discriminator_model: Stage 1 Discriminator Model. Its ``trainable``
            attribute is set to ``False`` as a side effect.

    Returns:
        A Keras ``Model`` taking ``[1024-wide vector, 100-wide noise,
        4x4x128 embedding map]`` and returning the discriminator output
        together with the generator's second output.
    """
    embedding_in = Input(shape=(1024,))
    noise_in = Input(shape=(100,))
    compressed_in = Input(shape=(4, 4, 128))

    # Generator inputs are ordered as (text, noise).
    generated_image, mean_logsigma = generator_model([embedding_in, noise_in])

    # Freeze the discriminator before wiring it into the combined model.
    discriminator_model.trainable = False
    validity = discriminator_model([generated_image, compressed_in])

    return Model(
        inputs=[embedding_in, noise_in, compressed_in],
        outputs=[validity, mean_logsigma],
    )
414
- ]
415
- },
416
- {
417
- "cell_type": "code",
418
- "execution_count": 12,
419
- "id": "bd351c9d",
420
- "metadata": {},
421
- "outputs": [
422
- {
423
- "name": "stdout",
424
- "output_type": "stream",
425
- "text": [
426
- "Model: \"model_2\"\n",
427
- "__________________________________________________________________________________________________\n",
428
- " Layer (type) Output Shape Param # Connected to \n",
429
- "==================================================================================================\n",
430
- " input_7 (InputLayer) [(None, 1024)] 0 [] \n",
431
- " \n",
432
- " input_8 (InputLayer) [(None, 100)] 0 [] \n",
433
- " \n",
434
- " model (Functional) [(None, 64, 64, 3), 10270400 ['input_7[0][0]', \n",
435
- " (None, 256)] 'input_8[0][0]'] \n",
436
- " \n",
437
- " input_9 (InputLayer) [(None, 4, 4, 128)] 0 [] \n",
438
- " \n",
439
- " model_1 (Functional) (None, 1) 2767361 ['model[0][0]', \n",
440
- " 'input_9[0][0]'] \n",
441
- " \n",
442
- "==================================================================================================\n",
443
- "Total params: 13,037,761\n",
444
- "Trainable params: 10,268,480\n",
445
- "Non-trainable params: 2,769,281\n",
446
- "__________________________________________________________________________________________________\n"
447
- ]
448
- }
449
- ],
450
- "source": [
451
- "ganstage1 = build_adversarial(generator, discriminator)\n",
452
- "ganstage1.summary()"
453
- ]
454
- },
455
- {
456
- "cell_type": "markdown",
457
- "id": "adf70416",
458
- "metadata": {},
459
- "source": [
460
- "### Train Utilities"
461
- ]
462
- },
463
- {
464
- "cell_type": "code",
465
- "execution_count": 13,
466
- "id": "730c9e8a",
467
- "metadata": {},
468
- "outputs": [],
469
- "source": [
def checkpoint_prefix():
    """Return the checkpoint file prefix ``ckpt`` inside ``./training_checkpoints``."""
    return os.path.join('./training_checkpoints', 'ckpt')
475
- "\n",
def adversarial_loss(y_true, y_pred):
    """Mean KL-style penalty computed from a mean/log-sigma tensor.

    Args:
        y_true: Unused; present only to satisfy the Keras loss signature.
        y_pred: 2-D tensor; columns ``[:128]`` are read as the mean and the
            remaining columns as log-sigma.

    Returns:
        Scalar mean of ``-ls + 0.5 * (-1 + exp(2 * ls) + mean ** 2)``.
    """
    mu, log_sigma = y_pred[:, :128], y_pred[:, 128:]
    penalty = -log_sigma + 0.5 * (-1 + tf.math.exp(2.0 * log_sigma) + tf.math.square(mu))
    return K.mean(penalty)
482
- "\n",
def normalize(input_image, real_image):
    """Rescale both arguments with ``value / 127.5 - 1``.

    Args:
        input_image: First image (or scalar) to rescale.
        real_image: Second image (or scalar) to rescale.

    Returns:
        Tuple ``(input_image, real_image)`` after rescaling.
    """
    def _rescale(pixels):
        return (pixels / 127.5) - 1

    return _rescale(input_image), _rescale(real_image)
488
- "\n",
def load_class_ids_filenames(class_id_path, filename_path):
    """Unpickle the class-id file and the filename file.

    Args:
        class_id_path: Path of the pickle holding the class ids.
        filename_path: Path of the pickle holding the filenames.

    Returns:
        Tuple ``(class_id, filename)`` with the two unpickled objects.
    """
    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    unpickled = []
    for pickle_path in (class_id_path, filename_path):
        with open(pickle_path, 'rb') as handle:
            unpickled.append(pickle.load(handle, encoding='latin1'))

    class_id, filename = unpickled
    return class_id, filename
497
- "\n",
def load_text_embeddings(text_embeddings):
    """Unpickle the embeddings file and return its content as a NumPy array.

    Args:
        text_embeddings: Path of the pickle file holding the embeddings.

    Returns:
        ``np.ndarray`` built from the unpickled object.
    """
    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open(text_embeddings, 'rb') as handle:
        raw_embeddings = pickle.load(handle, encoding='latin1')

    return np.array(raw_embeddings)
504
- "\n",
def load_bbox(data_path):
    """Map every image name (extension stripped) to its bounding-box values.

    Reads ``<data_path>/bounding_boxes.txt`` and ``<data_path>/images.txt``,
    both whitespace-delimited without a header. The first column of the
    bounding-box file is dropped and the remaining values are cast to int;
    the second column of the image file is taken as the filename. Rows of
    the two files are paired by position.

    Args:
        data_path: Directory containing the two text files.

    Returns:
        Dict mapping each filename without its extension to the list of
        integer bounding-box values on the same row.

    Raises:
        ValueError: If the bounding-box file has fewer rows than the image file.
    """
    bbox_path = data_path + '/bounding_boxes.txt'
    image_path = data_path + '/images.txt'

    # `sep=r'\s+'` replaces the deprecated `delim_whitespace=True`.
    bbox_df = pd.read_csv(bbox_path, sep=r'\s+', header=None).astype(int)
    filename_df = pd.read_csv(image_path, sep=r'\s+', header=None)

    filenames = filename_df[1].tolist()
    # One vectorised conversion instead of a per-row `iloc` lookup.
    boxes = bbox_df.iloc[:, 1:].values.tolist()

    if len(boxes) < len(filenames):
        raise ValueError(
            f'{bbox_path} has {len(boxes)} rows but {image_path} lists '
            f'{len(filenames)} images'
        )

    return {os.path.splitext(name)[0]: box for name, box in zip(filenames, boxes)}
520
- "\n",
def load_images(image_path, bounding_box, size):
    """Crops the image to a square around the bounding box and then resizes it.

    Args:
        image_path: Path of the image file to open.
        bounding_box: ``None`` to skip cropping, otherwise a sequence read as
            ``[x, y, width, height]`` (indices 2 and 3 are used as extents
            when computing the crop radius).
        size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        The RGB ``PIL.Image`` after optional cropping and bilinear resizing.
    """
    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    w, h = image.size

    if bounding_box is not None:
        box_x, box_y, box_w, box_h = bounding_box[:4]
        # Half-side of the square crop: 75% of the longer box edge.
        r = int(max(box_w, box_h) * 0.75)
        # Box centre is origin + half the extent. The previous
        # `(x + width) / 2` mixed a coordinate with an extent and shifted
        # the crop towards the image origin.
        c_x = int((2 * box_x + box_w) / 2)
        c_y = int((2 * box_y + box_h) / 2)
        # Clamp the crop window to the image bounds.
        y1 = max(0, c_y - r)
        y2 = min(h, c_y + r)
        x1 = max(0, c_x - r)
        x2 = min(w, c_x + r)
        image = image.crop((x1, y1, x2, y2))

    image = image.resize(size, PIL.Image.BILINEAR)
    return image
538
- "\n",
def load_data(filename_path, class_id_path, dataset_path, embeddings_path, size):
    """Loads the Dataset.

    Args:
        filename_path: Pickle file with the image names (without extension).
        class_id_path: Pickle file with one class id per image.
        dataset_path: Directory holding ``bounding_boxes.txt``, ``images.txt``
            and the ``images/`` folder.
        embeddings_path: Pickle file with the text embeddings; indexed as
            ``embeddings[i, :, :]`` for image ``i``.
        size: ``(width, height)`` every image is resized to.

    Returns:
        Tuple ``(x, y, embeds)`` of NumPy arrays: images, class ids and one
        randomly chosen embedding per image. Images that fail to load are
        reported on stdout and skipped.
    """
    # The arguments are used as given. Previously `dataset_path` was
    # overwritten by a hard-coded machine-specific path, so the caller's
    # value was ignored.
    class_id, filenames = load_class_ids_filenames(class_id_path, filename_path)
    embeddings = load_text_embeddings(embeddings_path)
    bbox_dict = load_bbox(dataset_path)

    x, y, embeds = [], [], []

    for i, filename in enumerate(filenames):
        bbox = bbox_dict[filename]

        try:
            image_path = f'{dataset_path}/images/{filename}.jpg'
            image = load_images(image_path, bbox, size)
            sample_embeddings = embeddings[i, :, :]
            # `randint`'s upper bound is exclusive: passing the row count
            # makes every embedding selectable. The old `count - 1` never
            # picked the last row and raised for single-row samples.
            embed_index = np.random.randint(0, sample_embeddings.shape[0])
            embed = sample_embeddings[embed_index, :]

            x.append(np.array(image))
            y.append(class_id[i])
            embeds.append(embed)

        except Exception as err:
            # Best effort: report the failure and continue with the next image.
            print(f'{err}')

    x = np.array(x)
    y = np.array(y)
    embeds = np.array(embeds)

    return x, y, embeds
580
- "\n",
def save_image(file, save_path):
    """Saves the image at the specified file path.

    Args:
        file: Image data accepted by ``Axes.imshow``.
        save_path: Destination path handed to ``Figure.savefig``; its parent
            directory is created when missing.
    """
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    image = plt.figure()
    try:
        ax = image.add_subplot(1, 1, 1)
        ax.imshow(file)
        ax.axis("off")
        image.savefig(save_path)
    finally:
        # Release the figure; otherwise every call leaves one more open
        # figure in pyplot's registry.
        plt.close(image)
589
- ]
590
- },
591
- {
592
- "cell_type": "code",
593
- "execution_count": 28,
594
- "id": "697f1dc6",
595
- "metadata": {},
596
- "outputs": [],
597
- "source": [
598
- "############################################################\n",
599
- "# StackGAN class\n",
600
- "############################################################\n",
601
- "\n",
class StackGanStage1(object):
    """StackGAN Stage 1 class.

    Builds the Stage 1 generator, discriminator, conditioning-augmentation
    network, embedding compressor and combined adversarial model, and
    provides the training loop that writes sample images to ``test/`` and
    weights to ``weights/``.
    """

    # NOTE(review): machine-specific absolute paths; override these class
    # attributes (or subclass) to run on another machine.
    data_dir = "D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/birds"
    train_dir = data_dir + "/train"
    test_dir = data_dir + "/test"
    embeddings_path_train = train_dir + "/char-CNN-RNN-embeddings.pickle"
    embeddings_path_test = test_dir + "/char-CNN-RNN-embeddings.pickle"
    filename_path_train = train_dir + "/filenames.pickle"
    filename_path_test = test_dir + "/filenames.pickle"
    class_id_path_train = train_dir + "/class_info.pickle"
    class_id_path_test = test_dir + "/class_info.pickle"
    dataset_path = "D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/CUB_200_2011"

    def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage1_generator_lr=0.0002, stage1_discriminator_lr=0.0002):
        """Build and compile every Stage 1 model.

        Args:
            epochs: Number of training epochs.
            z_dim: Width of the noise vector fed to the generator.
            batch_size: Number of samples per training batch.
            enable_function: Stored on the instance; not read by this class.
            stage1_generator_lr: Adam learning rate of the generator.
            stage1_discriminator_lr: Adam learning rate of the discriminator.
        """
        self.epochs = epochs
        self.z_dim = z_dim
        self.enable_function = enable_function
        self.stage1_generator_lr = stage1_generator_lr
        self.stage1_discriminator_lr = stage1_discriminator_lr
        self.image_size = 64
        self.conditioning_dim = 128
        self.batch_size = batch_size

        # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
        self.stage1_generator_optimizer = Adam(learning_rate=stage1_generator_lr, beta_1=0.5, beta_2=0.999)
        self.stage1_discriminator_optimizer = Adam(learning_rate=stage1_discriminator_lr, beta_1=0.5, beta_2=0.999)

        self.stage1_generator = build_stage1_generator()
        self.stage1_generator.compile(loss='mse', optimizer=self.stage1_generator_optimizer)

        self.stage1_discriminator = build_stage1_discriminator()
        self.stage1_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage1_discriminator_optimizer)

        self.ca_network = build_ca_network()
        self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')

        # NOTE(review): this model is only used through `predict_on_batch`
        # in `train_stage1`, so its weights stay at their initial values --
        # confirm that this is intended.
        self.embedding_compressor = build_embedding_compressor()
        self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')

        self.stage1_adversarial = build_adversarial(self.stage1_generator, self.stage1_discriminator)
        self.stage1_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage1_generator_optimizer)

        self.checkpoint1 = tf.train.Checkpoint(
            generator_optimizer=self.stage1_generator_optimizer,
            discriminator_optimizer=self.stage1_discriminator_optimizer,
            generator=self.stage1_generator,
            discriminator=self.stage1_discriminator)

    def visualize_stage1(self):
        """Running Tensorboard visualizations.
        """
        # `TensorBoard` is not imported by name in this notebook, so it is
        # reached through `tf.keras.callbacks`. The format string now has a
        # placeholder so each call gets its own timestamped log directory.
        tb = tf.keras.callbacks.TensorBoard(log_dir="logs/{}".format(time.time()))
        tb.set_model(self.stage1_generator)
        tb.set_model(self.stage1_discriminator)
        tb.set_model(self.ca_network)
        tb.set_model(self.embedding_compressor)

    def train_stage1(self):
        """Trains the stage1 StackGAN.

        Loads the train and test data, then for every batch trains the
        discriminator on (real image, text), (generated image, text) and
        (real image, mismatched text) pairs followed by one generator step
        through the adversarial model. Sample images are written every 5
        epochs and weights every 25 epochs and once more at the end.
        """
        image_shape = (self.image_size, self.image_size)

        # Class attributes must be reached through `self`; as bare names
        # they are not visible inside a method.
        x_train, y_train, train_embeds = load_data(filename_path=self.filename_path_train, class_id_path=self.class_id_path_train,
                                                   dataset_path=self.dataset_path, embeddings_path=self.embeddings_path_train, size=image_shape)

        x_test, y_test, test_embeds = load_data(filename_path=self.filename_path_test, class_id_path=self.class_id_path_test,
                                                dataset_path=self.dataset_path, embeddings_path=self.embeddings_path_test, size=image_shape)

        # Smoothed "real" targets; "fake" targets are plain zeros (the former
        # `np.zeros(...) * 0.1` evaluated to the same zeros).
        real = np.ones((self.batch_size, 1), dtype='float') * 0.9
        fake = np.zeros((self.batch_size, 1), dtype='float')
        # Placeholder target for the second adversarial output; its loss
        # function ignores `y_true`.
        ca_dummy_target = np.ones((self.batch_size, 256), dtype='float') * 0.9

        # Output folders used below must exist before anything is written.
        os.makedirs('test', exist_ok=True)
        os.makedirs('weights', exist_ok=True)

        for epoch in range(self.epochs):
            print(f'Epoch: {epoch}')

            gen_loss = []
            dis_loss = []

            num_batches = int(x_train.shape[0] / self.batch_size)

            for i in range(num_batches):
                batch_slice = slice(i * self.batch_size, (i + 1) * self.batch_size)

                latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
                embedding_text = train_embeds[batch_slice]

                # Compress the embedding and replicate it over the 4x4 grid
                # expected by the discriminator's second input.
                compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)
                compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, 128))
                compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))

                image_batch = x_train[batch_slice]
                image_batch = (image_batch - 127.5) / 127.5

                gen_images, _ = self.stage1_generator.predict_on_batch([embedding_text, latent_space])

                discriminator_loss = self.stage1_discriminator.train_on_batch([image_batch, compressed_embedding],
                                                                              np.reshape(real, (self.batch_size, 1)))

                discriminator_loss_gen = self.stage1_discriminator.train_on_batch([gen_images, compressed_embedding],
                                                                                  np.reshape(fake, (self.batch_size, 1)))

                # Mismatched pairs: REAL images shifted against the text by
                # one position. Using generated images here duplicated the
                # previous step instead of teaching text/image matching.
                discriminator_loss_wrong = self.stage1_discriminator.train_on_batch([image_batch[: self.batch_size-1], compressed_embedding[1:]],
                                                                                    np.reshape(fake[1:], (self.batch_size-1, 1)))

                # Discriminator loss
                d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_wrong))
                dis_loss.append(d_loss)

                print(f'Discriminator Loss: {d_loss}')

                # Generator loss (NumPy targets instead of backend tensors).
                g_loss = self.stage1_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],
                                                                [real, ca_dummy_target])

                print(f'Generator Loss: {g_loss}')
                gen_loss.append(g_loss)

            if epoch % 5 == 0:
                embedding_batch = test_embeds[0 : self.batch_size]
                # Match the noise batch to the number of test embeddings
                # actually available.
                latent_space = np.random.normal(0, 1, size=(len(embedding_batch), self.z_dim))
                gen_images, _ = self.stage1_generator.predict_on_batch([embedding_batch, latent_space])

                for sample_idx, image in enumerate(gen_images[:10]):
                    # The generator ends in tanh, so map [-1, 1] to [0, 1]
                    # before handing the array to `imshow`.
                    save_image((image + 1.0) / 2.0, f'test/gen_1_{epoch}_{sample_idx}')

            if epoch % 25 == 0:
                self.stage1_generator.save_weights('weights/stage1_gen.h5')
                self.stage1_discriminator.save_weights("weights/stage1_disc.h5")
                self.ca_network.save_weights('weights/stage1_ca.h5')
                self.embedding_compressor.save_weights('weights/stage1_embco.h5')
                self.stage1_adversarial.save_weights('weights/stage1_adv.h5')

        self.stage1_generator.save_weights('weights/stage1_gen.h5')
        self.stage1_discriminator.save_weights("weights/stage1_disc.h5")
730
- ]
731
- },
732
- {
733
- "cell_type": "code",
734
- "execution_count": null,
735
- "id": "517037ac",
736
- "metadata": {},
737
- "outputs": [],
738
- "source": [
739
- "stage1 = StackGanStage1()\n",
740
- "stage1.train_stage1()"
741
- ]
742
- },
743
- {
744
- "cell_type": "markdown",
745
- "id": "7d85b9f2",
746
- "metadata": {},
747
- "source": [
748
- "### Check test folder for generated images from Stage1 Generator\n",
749
- "### Let's Implement Stage 2 Generator"
750
- ]
751
- },
752
- {
753
- "cell_type": "code",
754
- "execution_count": 29,
755
- "id": "2e45c731",
756
- "metadata": {},
757
- "outputs": [],
758
- "source": [
759
- "############################################################\n",
760
- "# Stage 2 Generator Network\n",
761
- "############################################################\n",
762
- "\n",
def concat_along_dims(inputs):
    """Tile the conditioned text over a 16x16 grid and append it to the image features.

    Args:
        inputs: Pair ``[c, x]`` of conditioned text and encoded images.

    Returns:
        ``K.concatenate`` of the tiled ``c`` and ``x`` along axis 3.
    """
    conditioning, encoded = inputs[0], inputs[1]

    # Insert two singleton axes after the first axis, then repeat them 16x16.
    spatial = K.expand_dims(K.expand_dims(conditioning, axis=1), axis=1)
    spatial = K.tile(spatial, [1, 16, 16, 1])

    return K.concatenate([spatial, encoded], axis = 3)
779
- "\n",
def residual_block(input):
    """Two Conv2D(512) -> BatchNormalization stages with an identity skip connection.

    Args:
        input: Input tensor; it is added to the second stage's output, so
            its shape must match that output.

    Returns:
        ReLU of the sum of the block's output and ``input``.
    """
    def _conv_bn(tensor):
        # 3x3 'same' convolution without bias, followed by batch normalisation.
        tensor = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,
                        kernel_initializer='he_uniform')(tensor)
        return BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(tensor)

    branch = _conv_bn(input)
    branch = ReLU()(branch)
    branch = _conv_bn(branch)

    merged = add([branch, input])
    return ReLU()(merged)
802
- "\n",
803
def build_stage2_generator():
    """Assemble the Stage 2 generator of the StackGAN.

    The model takes a 1024-d text embedding and a 64x64x3 image, encodes the
    image with three convolutions (two of them strided), joins it with the
    conditioning variable, runs four residual blocks and four UpSamplingBlock
    stages, and ends with a 3-channel tanh convolution.

    Returns:
        Keras ``Model`` with inputs ``[text_embedding, image]`` and outputs
        ``[generated_image, mean_logsigma]``, where ``mean_logsigma`` is the
        256-d LeakyReLU output that feeds the conditioning augmentation.
    """
    text_embedding = Input(shape=(1024,))
    stage1_images = Input(shape=(64, 64, 3))

    # Conditioning augmentation branch
    projected = Dense(256)(text_embedding)
    mean_logsigma = LeakyReLU(alpha=0.2)(projected)
    conditioning = Lambda(conditioning_augmentation)(mean_logsigma)

    # Image encoder: a stride-1 stem without batch norm ...
    features = ZeroPadding2D(padding=(1, 1))(stage1_images)
    features = Conv2D(128, kernel_size=(3, 3), strides=1, use_bias=False,
                      kernel_initializer='he_uniform')(features)
    features = ReLU()(features)

    # ... followed by two stride-2 stages with batch norm.
    for n_filters in (256, 512):
        features = ZeroPadding2D(padding=(1, 1))(features)
        features = Conv2D(n_filters, kernel_size=(4, 4), strides=2, use_bias=False,
                          kernel_initializer='he_uniform')(features)
        features = BatchNormalization(gamma_initializer='ones',
                                      beta_initializer='zeros')(features)
        features = ReLU()(features)

    # Join the text conditioning block with the encoded image
    joint = concat_along_dims([conditioning, features])

    # Fuse the joint block back to 512 channels, then refine it
    features = ZeroPadding2D(padding=(1, 1))(joint)
    features = Conv2D(512, kernel_size=(3, 3), use_bias=False,
                      kernel_initializer='he_uniform')(features)
    features = BatchNormalization(gamma_initializer='ones',
                                  beta_initializer='zeros')(features)
    features = ReLU()(features)

    for _ in range(4):
        features = residual_block(features)

    # Upsampling stages with decreasing filter counts
    for n_filters in (512, 256, 128, 64):
        features = UpSamplingBlock(features, n_filters)

    features = Conv2D(3, kernel_size=(3, 3), padding='same', use_bias=False,
                      kernel_initializer='he_uniform')(features)
    generated = Activation('tanh')(features)

    return Model(inputs=[text_embedding, stage1_images],
                 outputs=[generated, mean_logsigma])
860
- ]
861
- },
862
- {
863
- "cell_type": "code",
864
- "execution_count": 30,
865
- "id": "76c876db",
866
- "metadata": {},
867
- "outputs": [
868
- {
869
- "name": "stdout",
870
- "output_type": "stream",
871
- "text": [
872
- "Model: \"model_3\"\n",
873
- "__________________________________________________________________________________________________\n",
874
- " Layer (type) Output Shape Param # Connected to \n",
875
- "==================================================================================================\n",
876
- " input_11 (InputLayer) [(None, 64, 64, 3)] 0 [] \n",
877
- " \n",
878
- " zero_padding2d (ZeroPadding2D) (None, 66, 66, 3) 0 ['input_11[0][0]'] \n",
879
- " \n",
880
- " conv2d_14 (Conv2D) (None, 64, 64, 128) 3456 ['zero_padding2d[0][0]'] \n",
881
- " \n",
882
- " re_lu_5 (ReLU) (None, 64, 64, 128) 0 ['conv2d_14[0][0]'] \n",
883
- " \n",
884
- " zero_padding2d_1 (ZeroPadding2 (None, 66, 66, 128) 0 ['re_lu_5[0][0]'] \n",
885
- " D) \n",
886
- " \n",
887
- " input_10 (InputLayer) [(None, 1024)] 0 [] \n",
888
- " \n",
889
- " conv2d_15 (Conv2D) (None, 32, 32, 256) 524288 ['zero_padding2d_1[0][0]'] \n",
890
- " \n",
891
- " dense_3 (Dense) (None, 256) 262400 ['input_10[0][0]'] \n",
892
- " \n",
893
- " batch_normalization_11 (BatchN (None, 32, 32, 256) 1024 ['conv2d_15[0][0]'] \n",
894
- " ormalization) \n",
895
- " \n",
896
- " leaky_re_lu_10 (LeakyReLU) (None, 256) 0 ['dense_3[0][0]'] \n",
897
- " \n",
898
- " re_lu_6 (ReLU) (None, 32, 32, 256) 0 ['batch_normalization_11[0][0]'] \n",
899
- " \n",
900
- " lambda_1 (Lambda) (None, 128) 0 ['leaky_re_lu_10[0][0]'] \n",
901
- " \n",
902
- " zero_padding2d_2 (ZeroPadding2 (None, 34, 34, 256) 0 ['re_lu_6[0][0]'] \n",
903
- " D) \n",
904
- " \n",
905
- " tf.expand_dims (TFOpLambda) (None, 1, 128) 0 ['lambda_1[0][0]'] \n",
906
- " \n",
907
- " conv2d_16 (Conv2D) (None, 16, 16, 512) 2097152 ['zero_padding2d_2[0][0]'] \n",
908
- " \n",
909
- " tf.expand_dims_1 (TFOpLambda) (None, 1, 1, 128) 0 ['tf.expand_dims[0][0]'] \n",
910
- " \n",
911
- " batch_normalization_12 (BatchN (None, 16, 16, 512) 2048 ['conv2d_16[0][0]'] \n",
912
- " ormalization) \n",
913
- " \n",
914
- " tf.tile (TFOpLambda) (None, 16, 16, 128) 0 ['tf.expand_dims_1[0][0]'] \n",
915
- " \n",
916
- " re_lu_7 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_12[0][0]'] \n",
917
- " \n",
918
- " tf.concat (TFOpLambda) (None, 16, 16, 640) 0 ['tf.tile[0][0]', \n",
919
- " 're_lu_7[0][0]'] \n",
920
- " \n",
921
- " zero_padding2d_3 (ZeroPadding2 (None, 18, 18, 640) 0 ['tf.concat[0][0]'] \n",
922
- " D) \n",
923
- " \n",
924
- " conv2d_17 (Conv2D) (None, 16, 16, 512) 2949120 ['zero_padding2d_3[0][0]'] \n",
925
- " \n",
926
- " batch_normalization_13 (BatchN (None, 16, 16, 512) 2048 ['conv2d_17[0][0]'] \n",
927
- " ormalization) \n",
928
- " \n",
929
- " re_lu_8 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_13[0][0]'] \n",
930
- " \n",
931
- " conv2d_18 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_8[0][0]'] \n",
932
- " \n",
933
- " batch_normalization_14 (BatchN (None, 16, 16, 512) 2048 ['conv2d_18[0][0]'] \n",
934
- " ormalization) \n",
935
- " \n",
936
- " re_lu_9 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_14[0][0]'] \n",
937
- " \n",
938
- " conv2d_19 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_9[0][0]'] \n",
939
- " \n",
940
- " batch_normalization_15 (BatchN (None, 16, 16, 512) 2048 ['conv2d_19[0][0]'] \n",
941
- " ormalization) \n",
942
- " \n",
943
- " add (Add) (None, 16, 16, 512) 0 ['batch_normalization_15[0][0]', \n",
944
- " 're_lu_8[0][0]'] \n",
945
- " \n",
946
- " re_lu_10 (ReLU) (None, 16, 16, 512) 0 ['add[0][0]'] \n",
947
- " \n",
948
- " conv2d_20 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_10[0][0]'] \n",
949
- " \n",
950
- " batch_normalization_16 (BatchN (None, 16, 16, 512) 2048 ['conv2d_20[0][0]'] \n",
951
- " ormalization) \n",
952
- " \n",
953
- " re_lu_11 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_16[0][0]'] \n",
954
- " \n",
955
- " conv2d_21 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_11[0][0]'] \n"
956
- ]
957
- },
958
- {
959
- "name": "stdout",
960
- "output_type": "stream",
961
- "text": [
962
- " \n",
963
- " batch_normalization_17 (BatchN (None, 16, 16, 512) 2048 ['conv2d_21[0][0]'] \n",
964
- " ormalization) \n",
965
- " \n",
966
- " add_1 (Add) (None, 16, 16, 512) 0 ['batch_normalization_17[0][0]', \n",
967
- " 're_lu_10[0][0]'] \n",
968
- " \n",
969
- " re_lu_12 (ReLU) (None, 16, 16, 512) 0 ['add_1[0][0]'] \n",
970
- " \n",
971
- " conv2d_22 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_12[0][0]'] \n",
972
- " \n",
973
- " batch_normalization_18 (BatchN (None, 16, 16, 512) 2048 ['conv2d_22[0][0]'] \n",
974
- " ormalization) \n",
975
- " \n",
976
- " re_lu_13 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_18[0][0]'] \n",
977
- " \n",
978
- " conv2d_23 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_13[0][0]'] \n",
979
- " \n",
980
- " batch_normalization_19 (BatchN (None, 16, 16, 512) 2048 ['conv2d_23[0][0]'] \n",
981
- " ormalization) \n",
982
- " \n",
983
- " add_2 (Add) (None, 16, 16, 512) 0 ['batch_normalization_19[0][0]', \n",
984
- " 're_lu_12[0][0]'] \n",
985
- " \n",
986
- " re_lu_14 (ReLU) (None, 16, 16, 512) 0 ['add_2[0][0]'] \n",
987
- " \n",
988
- " conv2d_24 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_14[0][0]'] \n",
989
- " \n",
990
- " batch_normalization_20 (BatchN (None, 16, 16, 512) 2048 ['conv2d_24[0][0]'] \n",
991
- " ormalization) \n",
992
- " \n",
993
- " re_lu_15 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_20[0][0]'] \n",
994
- " \n",
995
- " conv2d_25 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_15[0][0]'] \n",
996
- " \n",
997
- " batch_normalization_21 (BatchN (None, 16, 16, 512) 2048 ['conv2d_25[0][0]'] \n",
998
- " ormalization) \n",
999
- " \n",
1000
- " add_3 (Add) (None, 16, 16, 512) 0 ['batch_normalization_21[0][0]', \n",
1001
- " 're_lu_14[0][0]'] \n",
1002
- " \n",
1003
- " re_lu_16 (ReLU) (None, 16, 16, 512) 0 ['add_3[0][0]'] \n",
1004
- " \n",
1005
- " up_sampling2d_4 (UpSampling2D) (None, 32, 32, 512) 0 ['re_lu_16[0][0]'] \n",
1006
- " \n",
1007
- " conv2d_26 (Conv2D) (None, 32, 32, 512) 2359296 ['up_sampling2d_4[0][0]'] \n",
1008
- " \n",
1009
- " batch_normalization_22 (BatchN (None, 32, 32, 512) 2048 ['conv2d_26[0][0]'] \n",
1010
- " ormalization) \n",
1011
- " \n",
1012
- " re_lu_17 (ReLU) (None, 32, 32, 512) 0 ['batch_normalization_22[0][0]'] \n",
1013
- " \n",
1014
- " up_sampling2d_5 (UpSampling2D) (None, 64, 64, 512) 0 ['re_lu_17[0][0]'] \n",
1015
- " \n",
1016
- " conv2d_27 (Conv2D) (None, 64, 64, 256) 1179648 ['up_sampling2d_5[0][0]'] \n",
1017
- " \n",
1018
- " batch_normalization_23 (BatchN (None, 64, 64, 256) 1024 ['conv2d_27[0][0]'] \n",
1019
- " ormalization) \n",
1020
- " \n",
1021
- " re_lu_18 (ReLU) (None, 64, 64, 256) 0 ['batch_normalization_23[0][0]'] \n",
1022
- " \n",
1023
- " up_sampling2d_6 (UpSampling2D) (None, 128, 128, 25 0 ['re_lu_18[0][0]'] \n",
1024
- " 6) \n",
1025
- " \n",
1026
- " conv2d_28 (Conv2D) (None, 128, 128, 12 294912 ['up_sampling2d_6[0][0]'] \n",
1027
- " 8) \n",
1028
- " \n",
1029
- " batch_normalization_24 (BatchN (None, 128, 128, 12 512 ['conv2d_28[0][0]'] \n",
1030
- " ormalization) 8) \n",
1031
- " \n",
1032
- " re_lu_19 (ReLU) (None, 128, 128, 12 0 ['batch_normalization_24[0][0]'] \n",
1033
- " 8) \n",
1034
- " \n",
1035
- " up_sampling2d_7 (UpSampling2D) (None, 256, 256, 12 0 ['re_lu_19[0][0]'] \n",
1036
- " 8) \n",
1037
- " \n",
1038
- " conv2d_29 (Conv2D) (None, 256, 256, 64 73728 ['up_sampling2d_7[0][0]'] \n",
1039
- " ) \n",
1040
- " \n",
1041
- " batch_normalization_25 (BatchN (None, 256, 256, 64 256 ['conv2d_29[0][0]'] \n",
1042
- " ormalization) ) \n",
1043
- " \n",
1044
- " re_lu_20 (ReLU) (None, 256, 256, 64 0 ['batch_normalization_25[0][0]'] \n"
1045
- ]
1046
- },
1047
- {
1048
- "name": "stdout",
1049
- "output_type": "stream",
1050
- "text": [
1051
- " ) \n",
1052
- " \n",
1053
- " conv2d_30 (Conv2D) (None, 256, 256, 3) 1728 ['re_lu_20[0][0]'] \n",
1054
- " \n",
1055
- " activation_2 (Activation) (None, 256, 256, 3) 0 ['conv2d_30[0][0]'] \n",
1056
- " \n",
1057
- "==================================================================================================\n",
1058
- "Total params: 28,645,440\n",
1059
- "Trainable params: 28,632,768\n",
1060
- "Non-trainable params: 12,672\n",
1061
- "__________________________________________________________________________________________________\n"
1062
- ]
1063
- }
1064
- ],
1065
- "source": [
1066
- "generator_stage2 = build_stage2_generator()\n",
1067
- "generator_stage2.summary()"
1068
- ]
1069
- },
1070
- {
1071
- "cell_type": "code",
1072
- "execution_count": 31,
1073
- "id": "41de758a",
1074
- "metadata": {},
1075
- "outputs": [],
1076
- "source": [
1077
- "############################################################\n",
1078
- "# Stage 2 Discriminator Network\n",
1079
- "############################################################\n",
1080
- "\n",
1081
def build_stage2_discriminator():
    """Assemble the Stage 2 discriminator of the StackGAN.

    The model scores a 256x256x3 image together with a (4, 4, 128) compressed,
    spatially replicated text embedding.

    Returns:
        Keras ``Model`` with inputs ``[image, replicated_embedding]`` and a
        single sigmoid output.
    """
    image_input = Input(shape=(256, 256, 3))

    # Stem: strided convolution without batch norm
    features = Conv2D(64, kernel_size=(4, 4), padding='same', strides=2, use_bias=False,
                      kernel_initializer='he_uniform')(image_input)
    features = LeakyReLU(alpha=0.2)(features)

    # ConvBlock stack with default kernel/stride, growing the filter count
    for n_filters in (128, 256, 512, 1024, 2048):
        features = ConvBlock(features, n_filters)

    # 1x1 blocks that shrink the channel count again
    features = ConvBlock(features, 1024, (1, 1), 1)
    features = ConvBlock(features, 512, (1, 1), 1, False)

    # Bottleneck residual branch, added back onto `features`
    residual = ConvBlock(features, 128, (1, 1), 1)
    residual = ConvBlock(residual, 128, (3, 3), 1)
    residual = ConvBlock(residual, 512, (3, 3), 1, False)

    image_code = add([features, residual])
    image_code = LeakyReLU(alpha=0.2)(image_code)

    # Concatenate compressed and spatially replicated embedding
    embedding_input = Input(shape=(4, 4, 128))
    joint = concatenate([image_code, embedding_input])

    joint = Conv2D(512, kernel_size=(1, 1), strides=1, padding='same',
                   kernel_initializer='he_uniform')(joint)
    joint = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(joint)
    joint = LeakyReLU(alpha=0.2)(joint)

    # Flatten and reduce to a single probability
    logits = Dense(1)(Flatten()(joint))
    probability = Activation('sigmoid')(logits)

    return Model(inputs=[image_input, embedding_input], outputs=[probability])
1124
- ]
1125
- },
1126
- {
1127
- "cell_type": "code",
1128
- "execution_count": 32,
1129
- "id": "7dbcbc4e",
1130
- "metadata": {},
1131
- "outputs": [
1132
- {
1133
- "name": "stdout",
1134
- "output_type": "stream",
1135
- "text": [
1136
- "Model: \"model_4\"\n",
1137
- "__________________________________________________________________________________________________\n",
1138
- " Layer (type) Output Shape Param # Connected to \n",
1139
- "==================================================================================================\n",
1140
- " input_12 (InputLayer) [(None, 256, 256, 3 0 [] \n",
1141
- " )] \n",
1142
- " \n",
1143
- " conv2d_31 (Conv2D) (None, 128, 128, 64 3072 ['input_12[0][0]'] \n",
1144
- " ) \n",
1145
- " \n",
1146
- " leaky_re_lu_11 (LeakyReLU) (None, 128, 128, 64 0 ['conv2d_31[0][0]'] \n",
1147
- " ) \n",
1148
- " \n",
1149
- " conv2d_32 (Conv2D) (None, 64, 64, 128) 131072 ['leaky_re_lu_11[0][0]'] \n",
1150
- " \n",
1151
- " batch_normalization_26 (BatchN (None, 64, 64, 128) 512 ['conv2d_32[0][0]'] \n",
1152
- " ormalization) \n",
1153
- " \n",
1154
- " leaky_re_lu_12 (LeakyReLU) (None, 64, 64, 128) 0 ['batch_normalization_26[0][0]'] \n",
1155
- " \n",
1156
- " conv2d_33 (Conv2D) (None, 32, 32, 256) 524288 ['leaky_re_lu_12[0][0]'] \n",
1157
- " \n",
1158
- " batch_normalization_27 (BatchN (None, 32, 32, 256) 1024 ['conv2d_33[0][0]'] \n",
1159
- " ormalization) \n",
1160
- " \n",
1161
- " leaky_re_lu_13 (LeakyReLU) (None, 32, 32, 256) 0 ['batch_normalization_27[0][0]'] \n",
1162
- " \n",
1163
- " conv2d_34 (Conv2D) (None, 16, 16, 512) 2097152 ['leaky_re_lu_13[0][0]'] \n",
1164
- " \n",
1165
- " batch_normalization_28 (BatchN (None, 16, 16, 512) 2048 ['conv2d_34[0][0]'] \n",
1166
- " ormalization) \n",
1167
- " \n",
1168
- " leaky_re_lu_14 (LeakyReLU) (None, 16, 16, 512) 0 ['batch_normalization_28[0][0]'] \n",
1169
- " \n",
1170
- " conv2d_35 (Conv2D) (None, 8, 8, 1024) 8388608 ['leaky_re_lu_14[0][0]'] \n",
1171
- " \n",
1172
- " batch_normalization_29 (BatchN (None, 8, 8, 1024) 4096 ['conv2d_35[0][0]'] \n",
1173
- " ormalization) \n",
1174
- " \n",
1175
- " leaky_re_lu_15 (LeakyReLU) (None, 8, 8, 1024) 0 ['batch_normalization_29[0][0]'] \n",
1176
- " \n",
1177
- " conv2d_36 (Conv2D) (None, 4, 4, 2048) 33554432 ['leaky_re_lu_15[0][0]'] \n",
1178
- " \n",
1179
- " batch_normalization_30 (BatchN (None, 4, 4, 2048) 8192 ['conv2d_36[0][0]'] \n",
1180
- " ormalization) \n",
1181
- " \n",
1182
- " leaky_re_lu_16 (LeakyReLU) (None, 4, 4, 2048) 0 ['batch_normalization_30[0][0]'] \n",
1183
- " \n",
1184
- " conv2d_37 (Conv2D) (None, 4, 4, 1024) 2097152 ['leaky_re_lu_16[0][0]'] \n",
1185
- " \n",
1186
- " batch_normalization_31 (BatchN (None, 4, 4, 1024) 4096 ['conv2d_37[0][0]'] \n",
1187
- " ormalization) \n",
1188
- " \n",
1189
- " leaky_re_lu_17 (LeakyReLU) (None, 4, 4, 1024) 0 ['batch_normalization_31[0][0]'] \n",
1190
- " \n",
1191
- " conv2d_38 (Conv2D) (None, 4, 4, 512) 524288 ['leaky_re_lu_17[0][0]'] \n",
1192
- " \n",
1193
- " batch_normalization_32 (BatchN (None, 4, 4, 512) 2048 ['conv2d_38[0][0]'] \n",
1194
- " ormalization) \n",
1195
- " \n",
1196
- " conv2d_39 (Conv2D) (None, 4, 4, 128) 65536 ['batch_normalization_32[0][0]'] \n",
1197
- " \n",
1198
- " batch_normalization_33 (BatchN (None, 4, 4, 128) 512 ['conv2d_39[0][0]'] \n",
1199
- " ormalization) \n",
1200
- " \n",
1201
- " leaky_re_lu_18 (LeakyReLU) (None, 4, 4, 128) 0 ['batch_normalization_33[0][0]'] \n",
1202
- " \n",
1203
- " conv2d_40 (Conv2D) (None, 4, 4, 128) 147456 ['leaky_re_lu_18[0][0]'] \n",
1204
- " \n",
1205
- " batch_normalization_34 (BatchN (None, 4, 4, 128) 512 ['conv2d_40[0][0]'] \n",
1206
- " ormalization) \n",
1207
- " \n",
1208
- " leaky_re_lu_19 (LeakyReLU) (None, 4, 4, 128) 0 ['batch_normalization_34[0][0]'] \n",
1209
- " \n",
1210
- " conv2d_41 (Conv2D) (None, 4, 4, 512) 589824 ['leaky_re_lu_19[0][0]'] \n",
1211
- " \n",
1212
- " batch_normalization_35 (BatchN (None, 4, 4, 512) 2048 ['conv2d_41[0][0]'] \n",
1213
- " ormalization) \n",
1214
- " \n",
1215
- " add_4 (Add) (None, 4, 4, 512) 0 ['batch_normalization_32[0][0]', \n",
1216
- " 'batch_normalization_35[0][0]'] \n",
1217
- " \n",
1218
- " leaky_re_lu_20 (LeakyReLU) (None, 4, 4, 512) 0 ['add_4[0][0]'] \n",
1219
- " \n"
1220
- ]
1221
- },
1222
- {
1223
- "name": "stdout",
1224
- "output_type": "stream",
1225
- "text": [
1226
- " input_13 (InputLayer) [(None, 4, 4, 128)] 0 [] \n",
1227
- " \n",
1228
- " concatenate_2 (Concatenate) (None, 4, 4, 640) 0 ['leaky_re_lu_20[0][0]', \n",
1229
- " 'input_13[0][0]'] \n",
1230
- " \n",
1231
- " conv2d_42 (Conv2D) (None, 4, 4, 512) 328192 ['concatenate_2[0][0]'] \n",
1232
- " \n",
1233
- " batch_normalization_36 (BatchN (None, 4, 4, 512) 2048 ['conv2d_42[0][0]'] \n",
1234
- " ormalization) \n",
1235
- " \n",
1236
- " leaky_re_lu_21 (LeakyReLU) (None, 4, 4, 512) 0 ['batch_normalization_36[0][0]'] \n",
1237
- " \n",
1238
- " flatten_1 (Flatten) (None, 8192) 0 ['leaky_re_lu_21[0][0]'] \n",
1239
- " \n",
1240
- " dense_4 (Dense) (None, 1) 8193 ['flatten_1[0][0]'] \n",
1241
- " \n",
1242
- " activation_3 (Activation) (None, 1) 0 ['dense_4[0][0]'] \n",
1243
- " \n",
1244
- "==================================================================================================\n",
1245
- "Total params: 48,486,401\n",
1246
- "Trainable params: 48,472,833\n",
1247
- "Non-trainable params: 13,568\n",
1248
- "__________________________________________________________________________________________________\n"
1249
- ]
1250
- }
1251
- ],
1252
- "source": [
1253
- "discriminator_stage2 = build_stage2_discriminator()\n",
1254
- "discriminator_stage2.summary()"
1255
- ]
1256
- },
1257
- {
1258
- "cell_type": "code",
1259
- "execution_count": 33,
1260
- "id": "7131179e",
1261
- "metadata": {},
1262
- "outputs": [],
1263
- "source": [
1264
- "############################################################\n",
1265
- "# Stage 2 Adversarial Model\n",
1266
- "############################################################\n",
1267
- "\n",
1268
def stage2_adversarial_network(stage2_disc, stage2_gen, stage1_gen):
    """Chain Stage 1 generator -> Stage 2 generator -> Stage 2 discriminator.

    Side effect: sets ``trainable = False`` on both ``stage2_disc`` and
    ``stage1_gen``, so only the Stage 2 generator is updated through the
    returned model.

    Args:
        stage2_disc: Stage 2 Discriminator Model.
        stage2_gen: Stage 2 Generator Model.
        stage1_gen: Stage 1 Generator Model.

    Returns:
        Keras ``Model`` with inputs ``[text_embedding, latent_noise,
        replicated_embedding]`` and outputs ``[probability, mean_logsigma]``,
        where ``mean_logsigma`` is the second output of ``stage2_gen``.
    """
    text_embedding = Input(shape=(1024, ))
    latent_noise = Input(shape=(100, ))
    replicated_embedding = Input(shape=(4, 4, 128))

    # Low-resolution images come from the Stage 1 generator; its second
    # output is not needed here.
    low_res_images, _ = stage1_gen([text_embedding, latent_noise])

    # The discriminator is trained separately and the Stage 1 generator is
    # already trained, so both are frozen inside this combined model.
    stage2_disc.trainable = False
    stage1_gen.trainable = False

    high_res_images, mean_logsigma = stage2_gen([text_embedding, low_res_images])
    probability = stage2_disc([high_res_images, replicated_embedding])

    combined_inputs = [text_embedding, latent_noise, replicated_embedding]
    combined_outputs = [probability, mean_logsigma]
    return Model(inputs=combined_inputs, outputs=combined_outputs)
1293
- ]
1294
- },
1295
- {
1296
- "cell_type": "code",
1297
- "execution_count": 34,
1298
- "id": "a324bec8",
1299
- "metadata": {},
1300
- "outputs": [
1301
- {
1302
- "name": "stdout",
1303
- "output_type": "stream",
1304
- "text": [
1305
- "Model: \"model_5\"\n",
1306
- "__________________________________________________________________________________________________\n",
1307
- " Layer (type) Output Shape Param # Connected to \n",
1308
- "==================================================================================================\n",
1309
- " input_14 (InputLayer) [(None, 1024)] 0 [] \n",
1310
- " \n",
1311
- " input_15 (InputLayer) [(None, 100)] 0 [] \n",
1312
- " \n",
1313
- " model (Functional) [(None, 64, 64, 3), 10270400 ['input_14[0][0]', \n",
1314
- " (None, 256)] 'input_15[0][0]'] \n",
1315
- " \n",
1316
- " model_3 (Functional) [(None, 256, 256, 3 28645440 ['input_14[0][0]', \n",
1317
- " ), 'model[1][0]'] \n",
1318
- " (None, 256)] \n",
1319
- " \n",
1320
- " input_16 (InputLayer) [(None, 4, 4, 128)] 0 [] \n",
1321
- " \n",
1322
- " model_4 (Functional) (None, 1) 48486401 ['model_3[0][0]', \n",
1323
- " 'input_16[0][0]'] \n",
1324
- " \n",
1325
- "==================================================================================================\n",
1326
- "Total params: 87,402,241\n",
1327
- "Trainable params: 28,632,768\n",
1328
- "Non-trainable params: 58,769,473\n",
1329
- "__________________________________________________________________________________________________\n"
1330
- ]
1331
- }
1332
- ],
1333
- "source": [
1334
- "adversarial_stage2 = stage2_adversarial_network(discriminator_stage2, generator_stage2, generator)\n",
1335
- "adversarial_stage2.summary()"
1336
- ]
1337
- },
1338
- {
1339
- "cell_type": "code",
1340
- "execution_count": 35,
1341
- "id": "75ce4927",
1342
- "metadata": {},
1343
- "outputs": [],
1344
- "source": [
1345
class StackGanStage2(object):
    """StackGAN Stage 2 trainer: builds the Stage 2 models and runs the training loop.

    Args:
        epochs: Number of epochs
        z_dim: Latent space dimensions
        batch_size: Batch Size
        enable_function: Stored on the instance; not read anywhere in this class.
        stage2_generator_lr: Learning rate for stage 2 generator
        stage2_discriminator_lr: Learning rate for stage 2 discriminator
    """
    def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage2_generator_lr=0.0002, stage2_discriminator_lr=0.0002):
        self.epochs = epochs
        self.z_dim = z_dim
        self.enable_function = enable_function
        self.stage2_generator_lr = stage2_generator_lr
        self.stage2_discriminator_lr = stage2_discriminator_lr
        # Earlier revisions stored the stage 2 rates under these (misnamed)
        # attributes; they are kept as aliases so existing readers still work.
        self.stage1_generator_lr = stage2_generator_lr
        self.stage1_discriminator_lr = stage2_discriminator_lr
        self.low_image_size = 64
        self.high_image_size = 256
        self.conditioning_dim = 128
        self.batch_size = batch_size

        # `learning_rate` replaces the deprecated `lr` keyword.
        self.stage2_generator_optimizer = Adam(learning_rate=stage2_generator_lr, beta_1=0.5, beta_2=0.999)
        self.stage2_discriminator_optimizer = Adam(learning_rate=stage2_discriminator_lr, beta_1=0.5, beta_2=0.999)

        # The Stage 1 generator is restored from disk and only used for inference.
        self.stage1_generator = build_stage1_generator()
        self.stage1_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)
        self.stage1_generator.load_weights('weights/stage1_gen.h5')

        self.stage2_generator = build_stage2_generator()
        self.stage2_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)

        self.stage2_discriminator = build_stage2_discriminator()
        self.stage2_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage2_discriminator_optimizer)

        self.ca_network = build_ca_network()
        self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')

        self.embedding_compressor = build_embedding_compressor()
        self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')

        # Combined model; building it freezes the discriminator and the
        # Stage 1 generator (see stage2_adversarial_network).
        self.stage2_adversarial = stage2_adversarial_network(self.stage2_discriminator, self.stage2_generator, self.stage1_generator)
        self.stage2_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage2_generator_optimizer)

        self.checkpoint2 = tf.train.Checkpoint(
            generator_optimizer=self.stage2_generator_optimizer,
            discriminator_optimizer=self.stage2_discriminator_optimizer,
            generator=self.stage2_generator,
            discriminator=self.stage2_discriminator,
            generator1=self.stage1_generator)

    def visualize_stage2(self):
        """Running Tensorboard visualizations.
        """
        # Imported locally so the method does not rely on a notebook-level
        # import of TensorBoard being present.
        from keras.callbacks import TensorBoard

        # The original format string had no placeholder, so the timestamp was
        # silently dropped and every run logged into the same "logs/" folder.
        tb = TensorBoard(log_dir="logs/{}".format(time.time()))
        tb.set_model(self.stage2_generator)
        # NOTE(review): this second call looks like it replaces the model set
        # on the line above - confirm whether two callbacks were intended.
        tb.set_model(self.stage2_discriminator)

    def train_stage2(self):
        """Trains Stage 2 StackGAN.

        Per batch, the discriminator is updated three times (real images with
        their embeddings, generated images, and real images paired with
        shifted embeddings), then the generator is updated through the
        adversarial model. Sample images are written every 5 epochs and
        weights every 10 epochs, plus once more after the last epoch.
        """
        # Only the 256x256 training images and the train/test embeddings are
        # used below; low-resolution inputs come from the Stage 1 generator,
        # so the 64x64 copies of the dataset are no longer loaded.
        x_high_train, _, high_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,
            dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(256, 256))

        _, _, high_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test,
            dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(256, 256))

        # Make sure the output folders exist before anything is written.
        os.makedirs('results_stage2', exist_ok=True)
        os.makedirs('weights', exist_ok=True)

        # Smoothed "real" labels. The "fake" labels are plain zeros: the
        # original multiplied zeros by 0.1, which has no effect.
        real = np.ones((self.batch_size, 1), dtype='float') * 0.9
        fake = np.zeros((self.batch_size, 1), dtype='float')

        # Targets for the adversarial model, built once instead of allocating
        # new backend tensors on every batch. 256 is the width of the second
        # output of the Stage 2 generator.
        mean_logsigma_dim = 256
        generator_targets = [
            real.astype('float32'),
            np.full((self.batch_size, mean_logsigma_dim), 0.9, dtype='float32'),
        ]

        num_batches = x_high_train.shape[0] // self.batch_size

        for epoch in range(self.epochs):
            print(f'Epoch: {epoch}')

            gen_loss = []
            disc_loss = []

            for i in range(num_batches):
                batch = slice(i * self.batch_size, (i + 1) * self.batch_size)

                latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
                embedding_text = high_train_embeds[batch]

                # Compress the embedding and replicate it over a 4x4 grid to
                # match the discriminator's second input.
                compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)
                compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, self.conditioning_dim))
                compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))

                # Scale pixel values from [0, 255] to [-1, 1].
                image_batch = x_high_train[batch]
                image_batch = (image_batch - 127.5) / 127.5

                low_res_fakes, _ = self.stage1_generator.predict([embedding_text, latent_space], verbose=3)
                high_res_fakes, _ = self.stage2_generator.predict([embedding_text, low_res_fakes], verbose=3)

                # Real images with matching embeddings -> real
                discriminator_loss = self.stage2_discriminator.train_on_batch(
                    [image_batch, compressed_embedding], real)

                # Generated images with matching embeddings -> fake
                discriminator_loss_gen = self.stage2_discriminator.train_on_batch(
                    [high_res_fakes, compressed_embedding], fake)

                # Real images with embeddings shifted by one position -> fake
                discriminator_loss_fake = self.stage2_discriminator.train_on_batch(
                    [image_batch[:(self.batch_size - 1)], compressed_embedding[1:]], fake[1:])

                d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_fake))
                disc_loss.append(d_loss)

                print(f'Discriminator Loss: {d_loss}')

                g_loss = self.stage2_adversarial.train_on_batch(
                    [embedding_text, latent_space, compressed_embedding], generator_targets)
                gen_loss.append(g_loss)

                print(f'Generator Loss: {g_loss}')

            if epoch % 5 == 0:
                latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
                embedding_batch = high_test_embeds[0 : self.batch_size]

                low_fake_images, _ = self.stage1_generator.predict([embedding_batch, latent_space], verbose=3)
                high_fake_images, _ = self.stage2_generator.predict([embedding_batch, low_fake_images], verbose=3)

                # Distinct index name so the batch counter is not shadowed.
                for idx, image in enumerate(high_fake_images[:10]):
                    save_image(image, f'results_stage2/gen_{epoch}_{idx}.png')

            if epoch % 10 == 0:
                self.stage2_generator.save_weights('weights/stage2_gen.h5')
                self.stage2_discriminator.save_weights("weights/stage2_disc.h5")
                self.ca_network.save_weights('weights/stage2_ca.h5')
                self.embedding_compressor.save_weights('weights/stage2_embco.h5')
                self.stage2_adversarial.save_weights('weights/stage2_adv.h5')

        # Final snapshot after the last epoch.
        self.stage2_generator.save_weights('weights/stage2_gen.h5')
        self.stage2_discriminator.save_weights("weights/stage2_disc.h5")
1479
- ]
1480
- },
1481
- {
1482
- "cell_type": "code",
1483
- "execution_count": null,
1484
- "id": "0a91a164",
1485
- "metadata": {},
1486
- "outputs": [],
1487
- "source": [
1488
- "stage2 = StackGanStage2()\n",
1489
- "stage2.train_stage2()"
1490
- ]
1491
- }
1492
- ],
1493
- "metadata": {
1494
- "kernelspec": {
1495
- "display_name": "Python 3 (ipykernel)",
1496
- "language": "python",
1497
- "name": "python3"
1498
- },
1499
- "language_info": {
1500
- "codemirror_mode": {
1501
- "name": "ipython",
1502
- "version": 3
1503
- },
1504
- "file_extension": ".py",
1505
- "mimetype": "text/x-python",
1506
- "name": "python",
1507
- "nbconvert_exporter": "python",
1508
- "pygments_lexer": "ipython3",
1509
- "version": "3.10.9"
1510
- }
1511
- },
1512
- "nbformat": 4,
1513
- "nbformat_minor": 5
1514
- }