{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "a3677b66", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import os\n", "import pickle\n", "import time\n", "import random" ] }, { "cell_type": "code", "execution_count": 8, "id": "76ece7f8", "metadata": {}, "outputs": [], "source": [ "import PIL\n", "from PIL import Image\n", "import keras.backend as K\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from keras.optimizers import Adam\n", "from keras.models import Sequential\n", "from keras import layers,Model,Input\n", "from keras.layers import Lambda,Reshape,UpSampling2D,ReLU,add,ZeroPadding2D\n", "from keras.layers import Activation,BatchNormalization,Concatenate,concatenate\n", "from keras.layers import Dense,Conv2D,Flatten,Dropout,LeakyReLU\n", "from keras.preprocessing.image import ImageDataGenerator" ] }, { "cell_type": "markdown", "id": "b8980cd5", "metadata": {}, "source": [ "### Conditioning Augmentation Network" ] }, { "cell_type": "code", "execution_count": 3, "id": "d3027cda", "metadata": {}, "outputs": [], "source": [ "# conditioned by the text.\n", "def conditioning_augmentation(x):\n", " \"\"\"The mean_logsigma passed as argument is converted into the text conditioning variable.\n", "\n", " Args:\n", " x: The output of the text embedding passed through a FC layer with LeakyReLU non-linearity.\n", "\n", " Returns:\n", " c: The text conditioning variable after computation.\n", " \"\"\"\n", " mean = x[:, :128]\n", " log_sigma = x[:, 128:]\n", "\n", " stddev = tf.math.exp(log_sigma)\n", " epsilon = K.random_normal(shape=K.constant((mean.shape[1], ), dtype='int32'))\n", " c = mean + stddev * epsilon\n", " return c\n", "\n", "def build_ca_network():\n", " \"\"\"Builds the conditioning augmentation network.\n", " \"\"\"\n", " input_layer1 = Input(shape=(1024,)) #size of the vocabulary in the text data\n", " mls = Dense(256)(input_layer1)\n", " mls = LeakyReLU(alpha=0.2)(mls)\n", " ca = Lambda(conditioning_augmentation)(mls)\n", " return Model(inputs=[input_layer1], outputs=[ca]) " ] }, { "cell_type": "markdown", "id": "87340e8b", "metadata": {}, "source": [ "### Stage 1 Generator Network" ] }, { "cell_type": "code", "execution_count": 4, "id": "c430524d", "metadata": {}, "outputs": [], "source": [ "def UpSamplingBlock(x, num_kernels):\n", " \"\"\"An Upsample block with Upsampling2D, Conv2D, BatchNormalization and a ReLU activation.\n", "\n", " Args:\n", " x: The preceding layer as input.\n", " num_kernels: Number of kernels for the Conv2D layer.\n", "\n", " Returns:\n", " x: The final activation layer after the Upsampling block.\n", " \"\"\"\n", " x = UpSampling2D(size=(2,2))(x)\n", " x = Conv2D(num_kernels, kernel_size=(3,3), padding='same', strides=1, use_bias=False,\n", " kernel_initializer='he_uniform')(x)\n", " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x) #prevent from mode collapse\n", " x = ReLU()(x)\n", " return x\n", "\n", "\n", "def build_stage1_generator():\n", "\n", " input_layer1 = Input(shape=(1024,))\n", " ca = Dense(256)(input_layer1)\n", " ca = LeakyReLU(alpha=0.2)(ca)\n", "\n", " # Obtain the conditioned text\n", " c = Lambda(conditioning_augmentation)(ca)\n", "\n", " input_layer2 = Input(shape=(100,))\n", " concat = Concatenate(axis=1)([c, input_layer2]) \n", "\n", " x = Dense(16384, use_bias=False)(concat) \n", " x = ReLU()(x)\n", " x = Reshape((4, 4, 1024), input_shape=(16384,))(x)\n", "\n", " x = UpSamplingBlock(x, 512) \n", " x = UpSamplingBlock(x, 256)\n", " x = UpSamplingBlock(x, 128)\n", " x = UpSamplingBlock(x, 64) # upsampled our image to 64*64*3 \n", "\n", " x = Conv2D(3, kernel_size=3, padding='same', strides=1, use_bias=False,\n", " kernel_initializer='he_uniform')(x)\n", " x = Activation('tanh')(x)\n", "\n", " stage1_gen = Model(inputs=[input_layer1, input_layer2], outputs=[x, ca]) \n", " return stage1_gen" ] }, { "cell_type": "code", "execution_count": 5, "id": "0febcb4f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model\"\n", "__________________________________________________________________________________________________\n", " Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", " input_1 (InputLayer) [(None, 1024)] 0 [] \n", " \n", " dense (Dense) (None, 256) 262400 ['input_1[0][0]'] \n", " \n", " leaky_re_lu (LeakyReLU) (None, 256) 0 ['dense[0][0]'] \n", " \n", " lambda (Lambda) (None, 128) 0 ['leaky_re_lu[0][0]'] \n", " \n", " input_2 (InputLayer) [(None, 100)] 0 [] \n", " \n", " concatenate (Concatenate) (None, 228) 0 ['lambda[0][0]', \n", " 'input_2[0][0]'] \n", " \n", " dense_1 (Dense) (None, 16384) 3735552 ['concatenate[0][0]'] \n", " \n", " re_lu (ReLU) (None, 16384) 0 ['dense_1[0][0]'] \n", " \n", " reshape (Reshape) (None, 4, 4, 1024) 0 ['re_lu[0][0]'] \n", " \n", " up_sampling2d (UpSampling2D) (None, 8, 8, 1024) 0 ['reshape[0][0]'] \n", " \n", " conv2d (Conv2D) (None, 8, 8, 512) 4718592 ['up_sampling2d[0][0]'] \n", " \n", " batch_normalization (BatchNorm (None, 8, 8, 512) 2048 ['conv2d[0][0]'] \n", " alization) \n", " \n", " re_lu_1 (ReLU) (None, 8, 8, 512) 0 ['batch_normalization[0][0]'] \n", " \n", " up_sampling2d_1 (UpSampling2D) (None, 16, 16, 512) 0 ['re_lu_1[0][0]'] \n", " \n", " conv2d_1 (Conv2D) (None, 16, 16, 256) 1179648 ['up_sampling2d_1[0][0]'] \n", " \n", " batch_normalization_1 (BatchNo (None, 16, 16, 256) 1024 ['conv2d_1[0][0]'] \n", " rmalization) \n", " \n", " re_lu_2 (ReLU) (None, 16, 16, 256) 0 ['batch_normalization_1[0][0]'] \n", " \n", " up_sampling2d_2 (UpSampling2D) (None, 32, 32, 256) 0 ['re_lu_2[0][0]'] \n", " \n", " conv2d_2 (Conv2D) (None, 32, 32, 128) 294912 ['up_sampling2d_2[0][0]'] \n", " \n", " batch_normalization_2 (BatchNo (None, 32, 32, 128) 512 ['conv2d_2[0][0]'] \n", " rmalization) \n", " \n", " re_lu_3 (ReLU) (None, 32, 32, 128) 0 ['batch_normalization_2[0][0]'] \n", " \n", " up_sampling2d_3 (UpSampling2D) (None, 64, 64, 128) 0 ['re_lu_3[0][0]'] \n", " \n", " conv2d_3 (Conv2D) (None, 64, 64, 64) 73728 ['up_sampling2d_3[0][0]'] \n", " \n", " batch_normalization_3 (BatchNo (None, 64, 64, 64) 256 ['conv2d_3[0][0]'] \n", " rmalization) \n", " \n", " re_lu_4 (ReLU) (None, 64, 64, 64) 0 ['batch_normalization_3[0][0]'] \n", " \n", " conv2d_4 (Conv2D) (None, 64, 64, 3) 1728 ['re_lu_4[0][0]'] \n", " \n", " activation (Activation) (None, 64, 64, 3) 0 ['conv2d_4[0][0]'] \n", " \n", "==================================================================================================\n", "Total params: 10,270,400\n", "Trainable params: 10,268,480\n", "Non-trainable params: 1,920\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "generator = build_stage1_generator()\n", "generator.summary()" ] }, { "cell_type": "markdown", "id": "a14d9d1c", "metadata": {}, "source": [ "### Stage 1 Discriminator Network" ] }, { "cell_type": "code", "execution_count": 9, "id": "32b436ac", "metadata": {}, "outputs": [], "source": [ "def ConvBlock(x, num_kernels, kernel_size=(4,4), strides=2, activation=True):\n", " \"\"\"A ConvBlock with a Conv2D, BatchNormalization and LeakyReLU activation.\n", "\n", " Args:\n", " x: The preceding layer as input.\n", " num_kernels: Number of kernels for the Conv2D layer.\n", "\n", " Returns:\n", " x: The final activation layer after the ConvBlock block.\n", " \"\"\"\n", " x = Conv2D(num_kernels, kernel_size=kernel_size, padding='same', strides=strides, use_bias=False,\n", " kernel_initializer='he_uniform')(x)\n", " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n", " \n", " if activation:\n", " x = LeakyReLU(alpha=0.2)(x)\n", " return x\n", "\n", "\n", "def build_embedding_compressor():\n", " \"\"\"Build embedding compressor model\n", " \"\"\"\n", " input_layer1 = Input(shape=(1024,)) \n", " x = Dense(128)(input_layer1)\n", " x = ReLU()(x)\n", "\n", " model = Model(inputs=[input_layer1], outputs=[x])\n", " return model\n", "\n", "# the discriminator is fed with two inputs, the feature from Generator and the text embedding\n", "def build_stage1_discriminator():\n", " \"\"\"Builds the Stage 1 Discriminator that uses the 64x64 resolution images from the generator\n", " and the compressed and spatially replicated embedding.\n", "\n", " Returns:\n", " Stage 1 Discriminator Model for StackGAN.\n", " \"\"\"\n", " input_layer1 = Input(shape=(64, 64, 3)) \n", "\n", " x = Conv2D(64, kernel_size=(4,4), strides=2, padding='same', use_bias=False,\n", " kernel_initializer='he_uniform')(input_layer1)\n", " x = LeakyReLU(alpha=0.2)(x)\n", "\n", " x = ConvBlock(x, 128)\n", " x = ConvBlock(x, 256)\n", " x = ConvBlock(x, 512)\n", "\n", " # Obtain the compressed and spatially replicated text embedding\n", " input_layer2 = Input(shape=(4, 4, 128)) #2nd input to discriminator, text embedding\n", " concat = concatenate([x, input_layer2])\n", "\n", " x1 = Conv2D(512, kernel_size=(1,1), padding='same', strides=1, use_bias=False,\n", " kernel_initializer='he_uniform')(concat)\n", " x1 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n", " x1 = LeakyReLU(alpha=0.2)(x)\n", "\n", " # Flatten and add a FC layer to predict.\n", " x1 = Flatten()(x1)\n", " x1 = Dense(1)(x1)\n", " x1 = Activation('sigmoid')(x1)\n", "\n", " stage1_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x1]) \n", " return stage1_dis" ] }, { "cell_type": "code", "execution_count": 10, "id": "98090438", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_1\"\n", "__________________________________________________________________________________________________\n", " Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", " input_5 (InputLayer) [(None, 64, 64, 3)] 0 [] \n", " \n", " conv2d_9 (Conv2D) (None, 32, 32, 64) 3072 ['input_5[0][0]'] \n", " \n", " leaky_re_lu_5 (LeakyReLU) (None, 32, 32, 64) 0 ['conv2d_9[0][0]'] \n", " \n", " conv2d_10 (Conv2D) (None, 16, 16, 128) 131072 ['leaky_re_lu_5[0][0]'] \n", " \n", " batch_normalization_7 (BatchNo (None, 16, 16, 128) 512 ['conv2d_10[0][0]'] \n", " rmalization) \n", " \n", " leaky_re_lu_6 (LeakyReLU) (None, 16, 16, 128) 0 ['batch_normalization_7[0][0]'] \n", " \n", " conv2d_11 (Conv2D) (None, 8, 8, 256) 524288 ['leaky_re_lu_6[0][0]'] \n", " \n", " batch_normalization_8 (BatchNo (None, 8, 8, 256) 1024 ['conv2d_11[0][0]'] \n", " rmalization) \n", " \n", " leaky_re_lu_7 (LeakyReLU) (None, 8, 8, 256) 0 ['batch_normalization_8[0][0]'] \n", " \n", " conv2d_12 (Conv2D) (None, 4, 4, 512) 2097152 ['leaky_re_lu_7[0][0]'] \n", " \n", " batch_normalization_9 (BatchNo (None, 4, 4, 512) 2048 ['conv2d_12[0][0]'] \n", " rmalization) \n", " \n", " leaky_re_lu_8 (LeakyReLU) (None, 4, 4, 512) 0 ['batch_normalization_9[0][0]'] \n", " \n", " leaky_re_lu_9 (LeakyReLU) (None, 4, 4, 512) 0 ['leaky_re_lu_8[0][0]'] \n", " \n", " flatten (Flatten) (None, 8192) 0 ['leaky_re_lu_9[0][0]'] \n", " \n", " dense_2 (Dense) (None, 1) 8193 ['flatten[0][0]'] \n", " \n", " input_6 (InputLayer) [(None, 4, 4, 128)] 0 [] \n", " \n", " activation_1 (Activation) (None, 1) 0 ['dense_2[0][0]'] \n", " \n", "==================================================================================================\n", "Total params: 2,767,361\n", "Trainable params: 2,765,569\n", "Non-trainable params: 1,792\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "discriminator = build_stage1_discriminator()\n", "discriminator.summary()" ] }, { "cell_type": "markdown", "id": "cdc2a75a", "metadata": {}, "source": [ "### Stage 1 Adversarial Model (Building a GAN)" ] }, { "cell_type": "code", "execution_count": 11, "id": "5d0678f7", "metadata": {}, "outputs": [], "source": [ "# Building GAN with Generator and Discriminator\n", "\n", "def build_adversarial(generator_model, discriminator_model):\n", " \"\"\"Stage 1 Adversarial model.\n", "\n", " Args:\n", " generator_model: Stage 1 Generator Model\n", " discriminator_model: Stage 1 Discriminator Model\n", "\n", " Returns:\n", " Adversarial Model.\n", " \"\"\"\n", " input_layer1 = Input(shape=(1024,)) \n", " input_layer2 = Input(shape=(100,)) \n", " input_layer3 = Input(shape=(4, 4, 128)) \n", "\n", " x, ca = generator_model([input_layer1, input_layer2]) #text,noise\n", "\n", " discriminator_model.trainable = False \n", "\n", " probabilities = discriminator_model([x, input_layer3]) \n", " adversarial_model = Model(inputs=[input_layer1, input_layer2, input_layer3], outputs=[probabilities, ca])\n", " return adversarial_model" ] }, { "cell_type": "code", "execution_count": 12, "id": "bd351c9d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_2\"\n", "__________________________________________________________________________________________________\n", " Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", " input_7 (InputLayer) [(None, 1024)] 0 [] \n", " \n", " input_8 (InputLayer) [(None, 100)] 0 [] \n", " \n", " model (Functional) [(None, 64, 64, 3), 10270400 ['input_7[0][0]', \n", " (None, 256)] 'input_8[0][0]'] \n", " \n", " input_9 (InputLayer) [(None, 4, 4, 128)] 0 [] \n", " \n", " model_1 (Functional) (None, 1) 2767361 ['model[0][0]', \n", " 'input_9[0][0]'] \n", " \n", "==================================================================================================\n", "Total params: 13,037,761\n", "Trainable params: 10,268,480\n", "Non-trainable params: 2,769,281\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "ganstage1 = build_adversarial(generator, discriminator)\n", "ganstage1.summary()" ] }, { "cell_type": "markdown", "id": "adf70416", "metadata": {}, "source": [ "### Train Utilities" ] }, { "cell_type": "code", "execution_count": 13, "id": "730c9e8a", "metadata": {}, "outputs": [], "source": [ "def checkpoint_prefix():\n", " checkpoint_dir = './training_checkpoints'\n", " checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')\n", "\n", " return checkpoint_prefix\n", "\n", "def adversarial_loss(y_true, y_pred):\n", " mean = y_pred[:, :128]\n", " ls = y_pred[:, 128:]\n", " loss = -ls + 0.5 * (-1 + tf.math.exp(2.0 * ls) + tf.math.square(mean))\n", " loss = K.mean(loss)\n", " return loss\n", "\n", "def normalize(input_image, real_image):\n", " input_image = (input_image / 127.5) - 1\n", " real_image = (real_image / 127.5) - 1\n", "\n", " return input_image, real_image\n", "\n", "def load_class_ids_filenames(class_id_path, filename_path):\n", " with open(class_id_path, 'rb') as file:\n", " class_id = pickle.load(file, encoding='latin1')\n", "\n", " with open(filename_path, 'rb') as file:\n", " filename = pickle.load(file, encoding='latin1')\n", "\n", " return class_id, filename\n", "\n", "def load_text_embeddings(text_embeddings):\n", " with open(text_embeddings, 'rb') as file:\n", " embeds = pickle.load(file, encoding='latin1')\n", " embeds = np.array(embeds)\n", "\n", " return embeds\n", "\n", "def load_bbox(data_path):\n", " bbox_path = data_path + '/bounding_boxes.txt'\n", " image_path = data_path + '/images.txt'\n", " bbox_df = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int)\n", " filename_df = pd.read_csv(image_path, delim_whitespace=True, header=None)\n", "\n", " filenames = filename_df[1].tolist()\n", " bbox_dict = {i[:-4]:[] for i in filenames[:2]}\n", "\n", " for i in range(0, len(filenames)):\n", " bbox = bbox_df.iloc[i][1:].tolist()\n", " dict_key = filenames[i][:-4]\n", " bbox_dict[dict_key] = bbox\n", "\n", " return bbox_dict\n", "\n", "def load_images(image_path, bounding_box, size):\n", " \"\"\"Crops the image to the bounding box and then resizes it.\n", " \"\"\"\n", " image = Image.open(image_path).convert('RGB')\n", " w, h = image.size\n", " if bounding_box is not None:\n", " r = int(np.maximum(bounding_box[2], bounding_box[3]) * 0.75)\n", " c_x = int((bounding_box[0] + bounding_box[2]) / 2)\n", " c_y = int((bounding_box[1] + bounding_box[3]) / 2)\n", " y1 = np.maximum(0, c_y - r)\n", " y2 = np.minimum(h, c_y + r)\n", " x1 = np.maximum(0, c_x - r)\n", " x2 = np.minimum(w, c_x + r)\n", " image = image.crop([x1, y1, x2, y2])\n", "\n", " image = image.resize(size, PIL.Image.BILINEAR)\n", " return image\n", "\n", "def load_data(filename_path, class_id_path, dataset_path, embeddings_path, size):\n", " \"\"\"Loads the Dataset.\n", " \"\"\"\n", " data_dir = \"D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/birds\"\n", " train_dir = data_dir + \"/train\"\n", " test_dir = data_dir + \"/test\"\n", " embeddings_path_train = train_dir + \"/char-CNN-RNN-embeddings.pickle\"\n", " embeddings_path_test = test_dir + \"/char-CNN-RNN-embeddings.pickle\"\n", " filename_path_train = train_dir + \"/filenames.pickle\"\n", " filename_path_test = test_dir + \"/filenames.pickle\"\n", " class_id_path_train = train_dir + \"/class_info.pickle\"\n", " class_id_path_test = test_dir + \"/class_info.pickle\"\n", " dataset_path = \"D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/CUB_200_2011\"\n", " class_id, filenames = load_class_ids_filenames(class_id_path, filename_path)\n", " embeddings = load_text_embeddings(embeddings_path)\n", " bbox_dict = load_bbox(dataset_path)\n", "\n", " x, y, embeds = [], [], []\n", "\n", " for i, filename in enumerate(filenames):\n", " bbox = bbox_dict[filename]\n", "\n", " try:\n", " image_path = f'{dataset_path}/images/{filename}.jpg'\n", " image = load_images(image_path, bbox, size)\n", " e = embeddings[i, :, :]\n", " embed_index = np.random.randint(0, e.shape[0] - 1)\n", " embed = e[embed_index, :]\n", "\n", " x.append(np.array(image))\n", " y.append(class_id[i])\n", " embeds.append(embed)\n", "\n", " except Exception as e:\n", " print(f'{e}')\n", " \n", " x = np.array(x)\n", " y = np.array(y)\n", " embeds = np.array(embeds)\n", " \n", " return x, y, embeds\n", "\n", "def save_image(file, save_path):\n", " \"\"\"Saves the image at the specified file path.\n", " \"\"\"\n", " image = plt.figure()\n", " ax = image.add_subplot(1,1,1)\n", " ax.imshow(file)\n", " ax.axis(\"off\")\n", " plt.savefig(save_path)" ] }, { "cell_type": "code", "execution_count": 28, "id": "697f1dc6", "metadata": {}, "outputs": [], "source": [ "############################################################\n", "# StackGAN class\n", "############################################################\n", "\n", "class StackGanStage1(object):\n", " \"\"\"StackGAN Stage 1 class.\"\"\"\n", "\n", " data_dir = \"D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/birds\"\n", " train_dir = data_dir + \"/train\"\n", " test_dir = data_dir + \"/test\"\n", " embeddings_path_train = train_dir + \"/char-CNN-RNN-embeddings.pickle\"\n", " embeddings_path_test = test_dir + \"/char-CNN-RNN-embeddings.pickle\"\n", " filename_path_train = train_dir + \"/filenames.pickle\"\n", " filename_path_test = test_dir + \"/filenames.pickle\"\n", " class_id_path_train = train_dir + \"/class_info.pickle\"\n", " class_id_path_test = test_dir + \"/class_info.pickle\"\n", " dataset_path = \"D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/CUB_200_2011\"\n", " def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage1_generator_lr=0.0002, stage1_discriminator_lr=0.0002):\n", " self.epochs = epochs\n", " self.z_dim = z_dim\n", " self.enable_function = enable_function\n", " self.stage1_generator_lr = stage1_generator_lr\n", " self.stage1_discriminator_lr = stage1_discriminator_lr\n", " self.image_size = 64\n", " self.conditioning_dim = 128\n", " self.batch_size = batch_size\n", " \n", " self.stage1_generator_optimizer = Adam(lr=stage1_generator_lr, beta_1=0.5, beta_2=0.999)\n", " self.stage1_discriminator_optimizer = Adam(lr=stage1_discriminator_lr, beta_1=0.5, beta_2=0.999)\n", " \n", " self.stage1_generator = build_stage1_generator()\n", " self.stage1_generator.compile(loss='mse', optimizer=self.stage1_generator_optimizer)\n", "\n", " self.stage1_discriminator = build_stage1_discriminator()\n", " self.stage1_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage1_discriminator_optimizer)\n", "\n", " self.ca_network = build_ca_network()\n", " self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')\n", "\n", " self.embedding_compressor = build_embedding_compressor()\n", " self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')\n", "\n", " self.stage1_adversarial = build_adversarial(self.stage1_generator, self.stage1_discriminator)\n", " self.stage1_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage1_generator_optimizer)\n", "\n", " self.checkpoint1 = tf.train.Checkpoint(\n", " generator_optimizer=self.stage1_generator_optimizer,\n", " discriminator_optimizer=self.stage1_discriminator_optimizer,\n", " generator=self.stage1_generator,\n", " discriminator=self.stage1_discriminator)\n", "\n", " def visualize_stage1(self):\n", " \"\"\"Running Tensorboard visualizations.\n", " \"\"\"\n", " tb = TensorBoard(log_dir=\"logs/\".format(time.time()))\n", " tb.set_model(self.stage1_generator)\n", " tb.set_model(self.stage1_discriminator)\n", " tb.set_model(self.ca_network)\n", " tb.set_model(self.embedding_compressor)\n", "\n", " def train_stage1(self):\n", " \"\"\"Trains the stage1 StackGAN.\n", " \"\"\"\n", " x_train, y_train, train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,\n", " dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(64, 64))\n", "\n", " x_test, y_test, test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test, \n", " dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(64, 64))\n", "\n", " real = np.ones((self.batch_size, 1), dtype='float') * 0.9\n", " fake = np.zeros((self.batch_size, 1), dtype='float') * 0.1\n", "\n", " for epoch in range(self.epochs):\n", " print(f'Epoch: {epoch}')\n", "\n", " gen_loss = []\n", " dis_loss = []\n", "\n", " num_batches = int(x_train.shape[0] / self.batch_size)\n", "\n", " for i in range(num_batches):\n", "\n", " latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))\n", " embedding_text = train_embeds[i * self.batch_size:(i + 1) * self.batch_size]\n", " compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)\n", " compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, 128))\n", " compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))\n", "\n", " image_batch = x_train[i * self.batch_size:(i+1) * self.batch_size]\n", " image_batch = (image_batch - 127.5) / 127.5\n", "\n", " gen_images, _ = self.stage1_generator.predict([embedding_text, latent_space])\n", "\n", " discriminator_loss = self.stage1_discriminator.train_on_batch([image_batch, compressed_embedding], \n", " np.reshape(real, (self.batch_size, 1)))\n", "\n", " discriminator_loss_gen = self.stage1_discriminator.train_on_batch([gen_images, compressed_embedding],\n", " np.reshape(fake, (self.batch_size, 1)))\n", "\n", " discriminator_loss_wrong = self.stage1_discriminator.train_on_batch([gen_images[: self.batch_size-1], compressed_embedding[1:]], \n", " np.reshape(fake[1:], (self.batch_size-1, 1)))\n", "\n", "# Discriminator loss\n", " d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_wrong))\n", " dis_loss.append(d_loss)\n", "\n", " print(f'Discriminator Loss: {d_loss}')\n", "\n", " # Generator loss\n", " g_loss = self.stage1_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],\n", " [K.ones((self.batch_size, 1)) * 0.9, K.ones((self.batch_size, 256)) * 0.9])\n", "\n", " print(f'Generator Loss: {g_loss}')\n", " gen_loss.append(g_loss)\n", "\n", " if epoch % 5 == 0:\n", " latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))\n", " embedding_batch = test_embeds[0 : self.batch_size]\n", " gen_images, _ = self.stage1_generator.predict_on_batch([embedding_batch, latent_space])\n", "\n", " for i, image in enumerate(gen_images[:10]):\n", " save_image(image, f'test/gen_1_{epoch}_{i}')\n", "\n", " if epoch % 25 == 0:\n", " self.stage1_generator.save_weights('weights/stage1_gen.h5')\n", " self.stage1_discriminator.save_weights(\"weights/stage1_disc.h5\")\n", " self.ca_network.save_weights('weights/stage1_ca.h5')\n", " self.embedding_compressor.save_weights('weights/stage1_embco.h5')\n", " self.stage1_adversarial.save_weights('weights/stage1_adv.h5') \n", "\n", " self.stage1_generator.save_weights('weights/stage1_gen.h5')\n", " self.stage1_discriminator.save_weights(\"weights/stage1_disc.h5\")" ] }, { "cell_type": "code", "execution_count": null, "id": "517037ac", "metadata": {}, "outputs": [], "source": [ "stage1 = StackGanStage1()\n", "stage1.train_stage1()" ] }, { "cell_type": "markdown", "id": "7d85b9f2", "metadata": {}, "source": [ "### Check test folder for gernerated images from Stage1 Generator\n", "### Let's Implement Stage 2 Generator" ] }, { "cell_type": "code", "execution_count": 29, "id": "2e45c731", "metadata": {}, "outputs": [], "source": [ "############################################################\n", "# Stage 2 Generator Network\n", "############################################################\n", "\n", "def concat_along_dims(inputs):\n", " \"\"\"Joins the conditioned text with the encoded image along the dimensions.\n", "\n", " Args:\n", " inputs: consisting of conditioned text and encoded images as [c,x].\n", "\n", " Returns:\n", " Joint block along the dimensions.\n", " \"\"\"\n", " c = inputs[0]\n", " x = inputs[1]\n", "\n", " c = K.expand_dims(c, axis=1)\n", " c = K.expand_dims(c, axis=1)\n", " c = K.tile(c, [1, 16, 16, 1])\n", " return K.concatenate([c, x], axis = 3)\n", "\n", "def residual_block(input):\n", " \"\"\"Residual block with plain identity connections.\n", "\n", " Args:\n", " inputs: input layer or an encoded layer\n", "\n", " Returns:\n", " Layer with computed identity mapping.\n", " \"\"\"\n", " x = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,\n", " kernel_initializer='he_uniform')(input)\n", " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n", " x = ReLU()(x)\n", " \n", " x = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,\n", " kernel_initializer='he_uniform')(x)\n", " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n", " \n", " x = add([x, input])\n", " x = ReLU()(x)\n", "\n", " return x\n", "\n", "def build_stage2_generator():\n", " \"\"\"Build the Stage 2 Generator Network using the conditioning text and images from stage 1.\n", "\n", " Returns:\n", " Stage 2 Generator Model for StackGAN.\n", " \"\"\"\n", " input_layer1 = Input(shape=(1024,))\n", " input_images = Input(shape=(64, 64, 3))\n", "\n", " # Conditioning Augmentation\n", " ca = Dense(256)(input_layer1)\n", " mls = LeakyReLU(alpha=0.2)(ca)\n", " c = Lambda(conditioning_augmentation)(mls)\n", "\n", " # Downsampling block\n", " x = ZeroPadding2D(padding=(1,1))(input_images)\n", " x = Conv2D(128, kernel_size=(3,3), strides=1, use_bias=False,\n", " kernel_initializer='he_uniform')(x)\n", " x = ReLU()(x)\n", "\n", " x = ZeroPadding2D(padding=(1,1))(x)\n", " x = Conv2D(256, kernel_size=(4,4), strides=2, use_bias=False,\n", " kernel_initializer='he_uniform')(x)\n", " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n", " x = ReLU()(x)\n", "\n", " x = ZeroPadding2D(padding=(1,1))(x)\n", " x = Conv2D(512, kernel_size=(4,4), strides=2, use_bias=False,\n", " kernel_initializer='he_uniform')(x)\n", " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n", " x = ReLU()(x)\n", "\n", " # Concatenate text conditioning block with the encoded image\n", " concat = concat_along_dims([c, x])\n", "\n", " # Residual Blocks\n", " x = ZeroPadding2D(padding=(1,1))(concat)\n", " x = Conv2D(512, kernel_size=(3,3), use_bias=False, kernel_initializer='he_uniform')(x)\n", " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n", " x = ReLU()(x)\n", "\n", " x = residual_block(x)\n", " x = residual_block(x)\n", " x = residual_block(x)\n", " x = residual_block(x)\n", " \n", " # Upsampling Blocks\n", " x = UpSamplingBlock(x, 512)\n", " x = UpSamplingBlock(x, 256)\n", " x = UpSamplingBlock(x, 128)\n", " x = UpSamplingBlock(x, 64)\n", "\n", " x = Conv2D(3, kernel_size=(3,3), padding='same', use_bias=False, kernel_initializer='he_uniform')(x)\n", " x = Activation('tanh')(x)\n", " \n", " stage2_gen = Model(inputs=[input_layer1, input_images], outputs=[x, mls])\n", " return stage2_gen" ] }, { "cell_type": "code", "execution_count": 30, "id": "76c876db", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_3\"\n", "__________________________________________________________________________________________________\n", " Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", " input_11 (InputLayer) [(None, 64, 64, 3)] 0 [] \n", " \n", " zero_padding2d (ZeroPadding2D) (None, 66, 66, 3) 0 ['input_11[0][0]'] \n", " \n", " conv2d_14 (Conv2D) (None, 64, 64, 128) 3456 ['zero_padding2d[0][0]'] \n", " \n", " re_lu_5 (ReLU) (None, 64, 64, 128) 0 ['conv2d_14[0][0]'] \n", " \n", " zero_padding2d_1 (ZeroPadding2 (None, 66, 66, 128) 0 ['re_lu_5[0][0]'] \n", " D) \n", " \n", " input_10 (InputLayer) [(None, 1024)] 0 [] \n", " \n", " conv2d_15 (Conv2D) (None, 32, 32, 256) 524288 ['zero_padding2d_1[0][0]'] \n", " \n", " dense_3 (Dense) (None, 256) 262400 ['input_10[0][0]'] \n", " \n", " batch_normalization_11 (BatchN (None, 32, 32, 256) 1024 ['conv2d_15[0][0]'] \n", " ormalization) \n", " \n", " leaky_re_lu_10 (LeakyReLU) (None, 256) 0 ['dense_3[0][0]'] \n", " \n", " re_lu_6 (ReLU) (None, 32, 32, 256) 0 ['batch_normalization_11[0][0]'] \n", " \n", " lambda_1 (Lambda) (None, 128) 0 ['leaky_re_lu_10[0][0]'] \n", " \n", " zero_padding2d_2 (ZeroPadding2 (None, 34, 34, 256) 0 ['re_lu_6[0][0]'] \n", " D) \n", " \n", " tf.expand_dims (TFOpLambda) (None, 1, 128) 0 ['lambda_1[0][0]'] \n", " \n", " conv2d_16 (Conv2D) (None, 16, 16, 512) 2097152 ['zero_padding2d_2[0][0]'] \n", " \n", " tf.expand_dims_1 (TFOpLambda) (None, 1, 1, 128) 0 ['tf.expand_dims[0][0]'] \n", " \n", " batch_normalization_12 (BatchN (None, 16, 16, 512) 2048 ['conv2d_16[0][0]'] \n", " ormalization) \n", " \n", " tf.tile (TFOpLambda) (None, 16, 16, 128) 0 ['tf.expand_dims_1[0][0]'] \n", " \n", " re_lu_7 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_12[0][0]'] \n", " \n", " tf.concat (TFOpLambda) (None, 16, 16, 640) 0 ['tf.tile[0][0]', \n", " 're_lu_7[0][0]'] \n", " \n", " zero_padding2d_3 (ZeroPadding2 (None, 18, 18, 640) 0 ['tf.concat[0][0]'] \n", " D) \n", " \n", " conv2d_17 (Conv2D) (None, 16, 16, 512) 2949120 ['zero_padding2d_3[0][0]'] \n", " \n", " batch_normalization_13 (BatchN (None, 16, 16, 512) 2048 ['conv2d_17[0][0]'] \n", " ormalization) \n", " \n", " re_lu_8 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_13[0][0]'] \n", " \n", " conv2d_18 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_8[0][0]'] \n", " \n", " batch_normalization_14 (BatchN (None, 16, 16, 512) 2048 ['conv2d_18[0][0]'] \n", " ormalization) \n", " \n", " re_lu_9 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_14[0][0]'] \n", " \n", " conv2d_19 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_9[0][0]'] \n", " \n", " batch_normalization_15 (BatchN (None, 16, 16, 512) 2048 ['conv2d_19[0][0]'] \n", " ormalization) \n", " \n", " add (Add) (None, 16, 16, 512) 0 ['batch_normalization_15[0][0]', \n", " 're_lu_8[0][0]'] \n", " \n", " re_lu_10 (ReLU) (None, 16, 16, 512) 0 ['add[0][0]'] \n", " \n", " conv2d_20 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_10[0][0]'] \n", " \n", " batch_normalization_16 (BatchN (None, 16, 16, 512) 2048 ['conv2d_20[0][0]'] \n", " ormalization) \n", " \n", " re_lu_11 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_16[0][0]'] \n", " \n", " conv2d_21 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_11[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n", " batch_normalization_17 (BatchN (None, 16, 16, 512) 2048 ['conv2d_21[0][0]'] \n", " ormalization) \n", " \n", " add_1 (Add) (None, 16, 16, 512) 0 ['batch_normalization_17[0][0]', \n", " 're_lu_10[0][0]'] \n", " \n", " re_lu_12 (ReLU) (None, 16, 16, 512) 0 ['add_1[0][0]'] \n", " \n", " conv2d_22 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_12[0][0]'] \n", " \n", " batch_normalization_18 (BatchN (None, 16, 16, 512) 2048 ['conv2d_22[0][0]'] \n", " ormalization) \n", " \n", " re_lu_13 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_18[0][0]'] \n", " \n", " conv2d_23 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_13[0][0]'] \n", " \n", " batch_normalization_19 (BatchN (None, 16, 16, 512) 2048 ['conv2d_23[0][0]'] \n", " ormalization) \n", " \n", " add_2 (Add) (None, 16, 16, 512) 0 ['batch_normalization_19[0][0]', \n", " 're_lu_12[0][0]'] \n", " \n", " re_lu_14 (ReLU) (None, 16, 16, 512) 0 ['add_2[0][0]'] \n", " \n", " conv2d_24 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_14[0][0]'] \n", " \n", " batch_normalization_20 (BatchN (None, 16, 16, 512) 2048 ['conv2d_24[0][0]'] \n", " ormalization) \n", " \n", " re_lu_15 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_20[0][0]'] \n", " \n", " conv2d_25 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_15[0][0]'] \n", " \n", " batch_normalization_21 (BatchN (None, 16, 16, 512) 2048 ['conv2d_25[0][0]'] \n", " ormalization) \n", " \n", " add_3 (Add) (None, 16, 16, 512) 0 ['batch_normalization_21[0][0]', \n", " 're_lu_14[0][0]'] \n", " \n", " re_lu_16 (ReLU) (None, 16, 16, 512) 0 ['add_3[0][0]'] \n", " \n", " up_sampling2d_4 (UpSampling2D) (None, 32, 32, 512) 0 ['re_lu_16[0][0]'] \n", " \n", " conv2d_26 (Conv2D) (None, 32, 32, 512) 2359296 ['up_sampling2d_4[0][0]'] \n", " \n", " batch_normalization_22 (BatchN (None, 32, 32, 512) 2048 ['conv2d_26[0][0]'] \n", " ormalization) \n", " \n", " re_lu_17 (ReLU) (None, 32, 32, 512) 0 ['batch_normalization_22[0][0]'] \n", " \n", " up_sampling2d_5 (UpSampling2D) (None, 64, 64, 512) 0 ['re_lu_17[0][0]'] \n", " \n", " conv2d_27 (Conv2D) (None, 64, 64, 256) 1179648 ['up_sampling2d_5[0][0]'] \n", " \n", " batch_normalization_23 (BatchN (None, 64, 64, 256) 1024 ['conv2d_27[0][0]'] \n", " ormalization) \n", " \n", " re_lu_18 (ReLU) (None, 64, 64, 256) 0 ['batch_normalization_23[0][0]'] \n", " \n", " up_sampling2d_6 (UpSampling2D) (None, 128, 128, 25 0 ['re_lu_18[0][0]'] \n", " 6) \n", " \n", " conv2d_28 (Conv2D) (None, 128, 128, 12 294912 ['up_sampling2d_6[0][0]'] \n", " 8) \n", " \n", " batch_normalization_24 (BatchN (None, 128, 128, 12 512 ['conv2d_28[0][0]'] \n", " ormalization) 8) \n", " \n", " re_lu_19 (ReLU) (None, 128, 128, 12 0 ['batch_normalization_24[0][0]'] \n", " 8) \n", " \n", " up_sampling2d_7 (UpSampling2D) (None, 256, 256, 12 0 ['re_lu_19[0][0]'] \n", " 8) \n", " \n", " conv2d_29 (Conv2D) (None, 256, 256, 64 73728 ['up_sampling2d_7[0][0]'] \n", " ) \n", " \n", " batch_normalization_25 (BatchN (None, 256, 256, 64 256 ['conv2d_29[0][0]'] \n", " ormalization) ) \n", " \n", " re_lu_20 (ReLU) (None, 256, 256, 64 0 ['batch_normalization_25[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ) \n", " \n", " conv2d_30 (Conv2D) (None, 256, 256, 3) 1728 ['re_lu_20[0][0]'] \n", " \n", " activation_2 (Activation) (None, 256, 256, 3) 0 ['conv2d_30[0][0]'] \n", " \n", "==================================================================================================\n", "Total params: 28,645,440\n", "Trainable params: 28,632,768\n", "Non-trainable params: 12,672\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "generator_stage2 = build_stage2_generator()\n", "generator_stage2.summary()" ] }, { "cell_type": "code", "execution_count": 31, "id": "41de758a", "metadata": {}, "outputs": [], "source": [ "############################################################\n", "# Stage 2 Discriminator Network\n", "############################################################\n", "\n", "def build_stage2_discriminator():\n", " \"\"\"Builds the Stage 2 Discriminator that uses the 256x256 resolution images from the generator\n", " and the compressed and spatially replicated embeddings.\n", "\n", " Returns:\n", " Stage 2 Discriminator Model for StackGAN.\n", " \"\"\"\n", " input_layer1 = Input(shape=(256, 256, 3))\n", "\n", " x = Conv2D(64, kernel_size=(4,4), padding='same', strides=2, use_bias=False,\n", " kernel_initializer='he_uniform')(input_layer1)\n", " x = LeakyReLU(alpha=0.2)(x)\n", "\n", " x = ConvBlock(x, 128)\n", " x = ConvBlock(x, 256)\n", " x = ConvBlock(x, 512)\n", " x = ConvBlock(x, 1024)\n", " x = ConvBlock(x, 2048)\n", " x = ConvBlock(x, 1024, (1,1), 1)\n", " x = ConvBlock(x, 512, (1,1), 1, False)\n", "\n", " x1 = ConvBlock(x, 128, (1,1), 1)\n", " x1 = ConvBlock(x1, 128, (3,3), 1)\n", " x1 = ConvBlock(x1, 512, (3,3), 1, False)\n", "\n", " x2 = add([x, x1])\n", " x2 = LeakyReLU(alpha=0.2)(x2)\n", "\n", " # Concatenate compressed and spatially replicated embedding\n", " input_layer2 = Input(shape=(4, 4, 128))\n", " concat = concatenate([x2, input_layer2])\n", "\n", " x3 = Conv2D(512, kernel_size=(1,1), strides=1, padding='same', kernel_initializer='he_uniform')(concat)\n", " x3 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x3)\n", " x3 = LeakyReLU(alpha=0.2)(x3)\n", "\n", " # Flatten and add a FC layer\n", " x3 = Flatten()(x3)\n", " x3 = Dense(1)(x3)\n", " x3 = Activation('sigmoid')(x3)\n", "\n", " stage2_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x3])\n", " return stage2_dis" ] }, { "cell_type": "code", "execution_count": 32, "id": "7dbcbc4e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_4\"\n", "__________________________________________________________________________________________________\n", " Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", " input_12 (InputLayer) [(None, 256, 256, 3 0 [] \n", " )] \n", " \n", " conv2d_31 (Conv2D) (None, 128, 128, 64 3072 ['input_12[0][0]'] \n", " ) \n", " \n", " leaky_re_lu_11 (LeakyReLU) (None, 128, 128, 64 0 ['conv2d_31[0][0]'] \n", " ) \n", " \n", " conv2d_32 (Conv2D) (None, 64, 64, 128) 131072 ['leaky_re_lu_11[0][0]'] \n", " \n", " batch_normalization_26 (BatchN (None, 64, 64, 128) 512 ['conv2d_32[0][0]'] \n", " ormalization) \n", " \n", " leaky_re_lu_12 (LeakyReLU) (None, 64, 64, 128) 0 ['batch_normalization_26[0][0]'] \n", " \n", " conv2d_33 (Conv2D) (None, 32, 32, 256) 524288 ['leaky_re_lu_12[0][0]'] \n", " \n", " batch_normalization_27 (BatchN (None, 32, 32, 256) 1024 ['conv2d_33[0][0]'] \n", " ormalization) \n", " \n", " leaky_re_lu_13 (LeakyReLU) (None, 32, 32, 256) 0 ['batch_normalization_27[0][0]'] \n", " \n", " conv2d_34 (Conv2D) (None, 16, 16, 512) 2097152 ['leaky_re_lu_13[0][0]'] \n", " \n", " batch_normalization_28 (BatchN (None, 16, 16, 512) 2048 ['conv2d_34[0][0]'] \n", " ormalization) \n", " \n", " leaky_re_lu_14 (LeakyReLU) (None, 16, 16, 512) 0 ['batch_normalization_28[0][0]'] \n", " \n", " conv2d_35 (Conv2D) (None, 8, 8, 1024) 8388608 ['leaky_re_lu_14[0][0]'] \n", " \n", " batch_normalization_29 (BatchN (None, 8, 8, 1024) 4096 ['conv2d_35[0][0]'] \n", " ormalization) \n", " \n", " leaky_re_lu_15 (LeakyReLU) (None, 8, 8, 1024) 0 ['batch_normalization_29[0][0]'] \n", " \n", " conv2d_36 (Conv2D) (None, 4, 4, 2048) 33554432 ['leaky_re_lu_15[0][0]'] \n", " \n", " batch_normalization_30 (BatchN (None, 4, 4, 2048) 8192 ['conv2d_36[0][0]'] \n", " ormalization) \n", " \n", " leaky_re_lu_16 (LeakyReLU) (None, 4, 4, 2048) 0 ['batch_normalization_30[0][0]'] \n", " \n", " conv2d_37 (Conv2D) (None, 4, 4, 1024) 2097152 ['leaky_re_lu_16[0][0]'] \n", " \n", " batch_normalization_31 (BatchN (None, 4, 4, 1024) 4096 ['conv2d_37[0][0]'] \n", " ormalization) \n", " \n", " leaky_re_lu_17 (LeakyReLU) (None, 4, 4, 1024) 0 ['batch_normalization_31[0][0]'] \n", " \n", " conv2d_38 (Conv2D) (None, 4, 4, 512) 524288 ['leaky_re_lu_17[0][0]'] \n", " \n", " batch_normalization_32 (BatchN (None, 4, 4, 512) 2048 ['conv2d_38[0][0]'] \n", " ormalization) \n", " \n", " conv2d_39 (Conv2D) (None, 4, 4, 128) 65536 ['batch_normalization_32[0][0]'] \n", " \n", " batch_normalization_33 (BatchN (None, 4, 4, 128) 512 ['conv2d_39[0][0]'] \n", " ormalization) \n", " \n", " leaky_re_lu_18 (LeakyReLU) (None, 4, 4, 128) 0 ['batch_normalization_33[0][0]'] \n", " \n", " conv2d_40 (Conv2D) (None, 4, 4, 128) 147456 ['leaky_re_lu_18[0][0]'] \n", " \n", " batch_normalization_34 (BatchN (None, 4, 4, 128) 512 ['conv2d_40[0][0]'] \n", " ormalization) \n", " \n", " leaky_re_lu_19 (LeakyReLU) (None, 4, 4, 128) 0 ['batch_normalization_34[0][0]'] \n", " \n", " conv2d_41 (Conv2D) (None, 4, 4, 512) 589824 ['leaky_re_lu_19[0][0]'] \n", " \n", " batch_normalization_35 (BatchN (None, 4, 4, 512) 2048 ['conv2d_41[0][0]'] \n", " ormalization) \n", " \n", " add_4 (Add) (None, 4, 4, 512) 0 ['batch_normalization_32[0][0]', \n", " 'batch_normalization_35[0][0]'] \n", " \n", " leaky_re_lu_20 (LeakyReLU) (None, 4, 4, 512) 0 ['add_4[0][0]'] \n", " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " input_13 (InputLayer) [(None, 4, 4, 128)] 0 [] \n", " \n", " concatenate_2 (Concatenate) (None, 4, 4, 640) 0 ['leaky_re_lu_20[0][0]', \n", " 'input_13[0][0]'] \n", " \n", " conv2d_42 (Conv2D) (None, 4, 4, 512) 328192 ['concatenate_2[0][0]'] \n", " \n", " batch_normalization_36 (BatchN (None, 4, 4, 512) 2048 ['conv2d_42[0][0]'] \n", " ormalization) \n", " \n", " leaky_re_lu_21 (LeakyReLU) (None, 4, 4, 512) 0 ['batch_normalization_36[0][0]'] \n", " \n", " flatten_1 (Flatten) (None, 8192) 0 ['leaky_re_lu_21[0][0]'] \n", " \n", " dense_4 (Dense) (None, 1) 8193 ['flatten_1[0][0]'] \n", " \n", " activation_3 (Activation) (None, 1) 0 ['dense_4[0][0]'] \n", " \n", "==================================================================================================\n", "Total params: 48,486,401\n", "Trainable params: 48,472,833\n", "Non-trainable params: 13,568\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "discriminator_stage2 = build_stage2_discriminator()\n", "discriminator_stage2.summary()" ] }, { "cell_type": "code", "execution_count": 33, "id": "7131179e", "metadata": {}, "outputs": [], "source": [ "############################################################\n", "# Stage 2 Adversarial Model\n", "############################################################\n", "\n", "def stage2_adversarial_network(stage2_disc, stage2_gen, stage1_gen):\n", " \"\"\"Stage 2 Adversarial Network.\n", "\n", " Args:\n", " stage2_disc: Stage 2 Discriminator Model.\n", " stage2_gen: Stage 2 Generator Model.\n", " stage1_gen: Stage 1 Generator Model.\n", "\n", " Returns:\n", " Stage 2 Adversarial network.\n", " \"\"\"\n", " conditioned_embedding = Input(shape=(1024, ))\n", " latent_space = Input(shape=(100, ))\n", " compressed_replicated = Input(shape=(4, 4, 128))\n", " \n", " #the discriminator is trained separately and stage1_gen already trained, and this is the reason why we freeze its layers by setting the property trainable=false\n", " input_images, ca = stage1_gen([conditioned_embedding, latent_space])\n", " stage2_disc.trainable = False\n", " stage1_gen.trainable = False\n", "\n", " images, ca2 = stage2_gen([conditioned_embedding, input_images])\n", " probability = stage2_disc([images, compressed_replicated])\n", "\n", " return Model(inputs=[conditioned_embedding, latent_space, compressed_replicated],\n", " outputs=[probability, ca2])" ] }, { "cell_type": "code", "execution_count": 34, "id": "a324bec8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_5\"\n", "__________________________________________________________________________________________________\n", " Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", " input_14 (InputLayer) [(None, 1024)] 0 [] \n", " \n", " input_15 (InputLayer) [(None, 100)] 0 [] \n", " \n", " model (Functional) [(None, 64, 64, 3), 10270400 ['input_14[0][0]', \n", " (None, 256)] 'input_15[0][0]'] \n", " \n", " model_3 (Functional) [(None, 256, 256, 3 28645440 ['input_14[0][0]', \n", " ), 'model[1][0]'] \n", " (None, 256)] \n", " \n", " input_16 (InputLayer) [(None, 4, 4, 128)] 0 [] \n", " \n", " model_4 (Functional) (None, 1) 48486401 ['model_3[0][0]', \n", " 'input_16[0][0]'] \n", " \n", "==================================================================================================\n", "Total params: 87,402,241\n", "Trainable params: 28,632,768\n", "Non-trainable params: 58,769,473\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "adversarial_stage2 = stage2_adversarial_network(discriminator_stage2, generator_stage2, generator)\n", "adversarial_stage2.summary()" ] }, { "cell_type": "code", "execution_count": 35, "id": "75ce4927", "metadata": {}, "outputs": [], "source": [ "class StackGanStage2(object):\n", " \"\"\"StackGAN Stage 2 class.\n", "\n", " Args:\n", " epochs: Number of epochs\n", " z_dim: Latent space dimensions\n", " batch_size: Batch Size\n", " enable_function: If True, training function is decorated with tf.function\n", " stage2_generator_lr: Learning rate for stage 2 generator\n", " stage2_discriminator_lr: Learning rate for stage 2 discriminator\n", " \"\"\"\n", " def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage2_generator_lr=0.0002, stage2_discriminator_lr=0.0002):\n", " self.epochs = epochs\n", " self.z_dim = z_dim\n", " self.enable_function = enable_function\n", " self.stage1_generator_lr = stage2_generator_lr\n", " self.stage1_discriminator_lr = stage2_discriminator_lr\n", " self.low_image_size = 64\n", " self.high_image_size = 256\n", " self.conditioning_dim = 128\n", " self.batch_size = batch_size\n", " self.stage2_generator_optimizer = Adam(lr=stage2_generator_lr, beta_1=0.5, beta_2=0.999)\n", " self.stage2_discriminator_optimizer = Adam(lr=stage2_discriminator_lr, beta_1=0.5, beta_2=0.999)\n", " self.stage1_generator = build_stage1_generator()\n", " self.stage1_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)\n", " self.stage1_generator.load_weights('weights/stage1_gen.h5')\n", " self.stage2_generator = build_stage2_generator()\n", " self.stage2_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)\n", "\n", " self.stage2_discriminator = build_stage2_discriminator()\n", " self.stage2_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage2_discriminator_optimizer)\n", "\n", " self.ca_network = build_ca_network()\n", " self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')\n", "\n", " self.embedding_compressor = build_embedding_compressor()\n", " self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')\n", "\n", " self.stage2_adversarial = stage2_adversarial_network(self.stage2_discriminator, self.stage2_generator, self.stage1_generator)\n", " self.stage2_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage2_generator_optimizer)\t\n", "\n", " self.checkpoint2 = tf.train.Checkpoint(\n", " generator_optimizer=self.stage2_generator_optimizer,\n", " discriminator_optimizer=self.stage2_discriminator_optimizer,\n", " generator=self.stage2_generator,\n", " discriminator=self.stage2_discriminator,\n", " generator1=self.stage1_generator)\n", "\n", " def visualize_stage2(self):\n", " \"\"\"Running Tensorboard visualizations.\n", " \"\"\"\n", " tb = TensorBoard(log_dir=\"logs/\".format(time.time()))\n", " tb.set_model(self.stage2_generator)\n", " tb.set_model(self.stage2_discriminator)\n", "\n", " def train_stage2(self):\n", " \"\"\"Trains Stage 2 StackGAN.\n", " \"\"\"\n", " x_high_train, y_high_train, high_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,\n", " dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(256, 256))\n", "\n", " x_high_test, y_high_test, high_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test, \n", " dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(256, 256))\n", "\n", " x_low_train, y_low_train, low_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,\n", " dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(64, 64))\n", "\n", " x_low_test, y_low_test, low_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test, \n", " dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(64, 64))\n", "\n", " real = np.ones((self.batch_size, 1), dtype='float') * 0.9\n", " fake = np.zeros((self.batch_size, 1), dtype='float') * 0.1\n", "\n", " for epoch in range(self.epochs):\n", " print(f'Epoch: {epoch}')\n", "\n", " gen_loss = []\n", " disc_loss = []\n", "\n", " num_batches = int(x_high_train.shape[0] / self.batch_size)\n", "\n", " for i in range(num_batches):\n", "\n", " latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))\n", " embedding_text = high_train_embeds[i * self.batch_size:(i + 1) * self.batch_size]\n", " compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)\n", " compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, self.conditioning_dim))\n", " compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))\n", "\n", " image_batch = x_high_train[i * self.batch_size:(i+1) * self.batch_size]\n", " image_batch = (image_batch - 127.5) / 127.5\n", " \n", " low_res_fakes, _ = self.stage1_generator.predict([embedding_text, latent_space], verbose=3)\n", " high_res_fakes, _ = self.stage2_generator.predict([embedding_text, low_res_fakes], verbose=3)\n", "\n", " discriminator_loss = self.stage2_discriminator.train_on_batch([image_batch, compressed_embedding],\n", " np.reshape(real, (self.batch_size, 1)))\n", "\n", " discriminator_loss_gen = self.stage2_discriminator.train_on_batch([high_res_fakes, compressed_embedding],\n", " np.reshape(fake, (self.batch_size, 1)))\n", "\n", " discriminator_loss_fake = self.stage2_discriminator.train_on_batch([image_batch[:(self.batch_size-1)], compressed_embedding[1:]],\n", " np.reshape(fake[1:], (self.batch_size - 1, 1)))\n", "\n", " d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_fake))\n", " disc_loss.append(d_loss)\n", "\n", " print(f'Discriminator Loss: {d_loss}')\n", "\n", " g_loss = self.stage2_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],\n", " [K.ones((self.batch_size, 1)) * 0.9, K.ones((self.batch_size, 256)) * 0.9])\n", " gen_loss.append(g_loss)\n", "\n", " print(f'Generator Loss: {g_loss}')\n", "\n", " if epoch % 5 == 0:\n", " latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))\n", " embedding_batch = high_test_embeds[0 : self.batch_size]\n", "\n", " low_fake_images, _ = self.stage1_generator.predict([embedding_batch, latent_space], verbose=3)\n", " high_fake_images, _ = self.stage2_generator.predict([embedding_batch, low_fake_images], verbose=3)\n", "\n", " for i, image in enumerate(high_fake_images[:10]):\n", " save_image(image, f'results_stage2/gen_{epoch}_{i}.png')\n", "\n", " if epoch % 10 == 0:\n", " self.stage2_generator.save_weights('weights/stage2_gen.h5')\n", " self.stage2_discriminator.save_weights(\"weights/stage2_disc.h5\")\n", " self.ca_network.save_weights('weights/stage2_ca.h5')\n", " self.embedding_compressor.save_weights('weights/stage2_embco.h5')\n", " self.stage2_adversarial.save_weights('weights/stage2_adv.h5')\n", "\n", " self.stage2_generator.save_weights('weights/stage2_gen.h5')\n", " self.stage2_discriminator.save_weights(\"weights/stage2_disc.h5\")" ] }, { "cell_type": "code", "execution_count": null, "id": "0a91a164", "metadata": {}, "outputs": [], "source": [ "stage2 = StackGanStage2()\n", "stage2.train_stage2()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }