File size: 127,520 Bytes
9bf40bb |
|
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from __future__ import print_function, division\n",
"\n",
"from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization\n",
"from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, BatchNormalization, Activation, ZeroPadding2D\n",
"from keras.layers.advanced_activations import LeakyReLU\n",
"from keras.layers.convolutional import UpSampling2D, Conv2D\n",
"from keras.models import Sequential, Model\n",
"from keras.optimizers import Adam\n",
"\n",
"from keras.preprocessing.image import img_to_array\n",
"from keras.preprocessing.image import load_img\n",
"\n",
"from sklearn.utils import shuffle\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import datetime\n",
"import natsort\n",
"import scipy\n",
"import sys\n",
"import os\n",
"import cv2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def load_filename(path):\n",
"    \"\"\"Return the full paths of every entry in directory `path`, natural-sorted.\n",
"\n",
"    Natural sorting orders e.g. 'img2' before 'img10', so two folders with\n",
"    matching file names line up entry-by-entry.\n",
"    \"\"\"\n",
"    # os.path.join works whether or not `path` ends with a separator\n",
"    full_paths = [os.path.join(path, name) for name in os.listdir(path)]\n",
"    return natsort.natsorted(full_paths, reverse=False)\n",
"\n",
"# load every image in a list of file paths into memory\n",
"def load_images(list_path, size=(256, 256)):\n",
"    \"\"\"Load, resize and stack images into one float array scaled to [-1, 1].\n",
"\n",
"    Args:\n",
"        list_path: iterable of image file paths.\n",
"        size: target_size passed to load_img.\n",
"\n",
"    Returns:\n",
"        numpy array with one entry per path; pixel values are mapped with\n",
"        (x - 127.5) / 127.5.\n",
"    \"\"\"\n",
"    def _load_scaled(filename):\n",
"        # decode + resize, convert to an array, then rescale\n",
"        arr = img_to_array(load_img(filename, target_size=size))\n",
"        return (arr - 127.5) / 127.5\n",
"\n",
"    return np.asarray([_load_scaled(fname) for fname in list_path])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# select a batch of random samples, returns images and target\n",
"def generate_real_samples(dataset, n_samples, patch_shape):\n",
"    \"\"\"Draw `n_samples` random aligned (A, B) pairs and label them as real.\n",
"\n",
"    Indices are drawn uniformly with replacement.\n",
"\n",
"    Returns:\n",
"        ([A_batch, B_batch], ones of shape (n_samples, patch_shape, patch_shape, 1))\n",
"    \"\"\"\n",
"    src_images, tgt_images = dataset\n",
"    picks = np.random.randint(0, src_images.shape[0], n_samples)\n",
"    real_labels = np.ones((n_samples, patch_shape, patch_shape, 1))\n",
"    return [src_images[picks], tgt_images[picks]], real_labels\n",
"\n",
"# generate a batch of images, returns images and targets\n",
"def generate_fake_samples(g_model, samples, patch_shape):\n",
"    \"\"\"Run the generator on `samples` and label every output as fake.\n",
"\n",
"    Returns:\n",
"        (generated images, zeros of shape (n, patch_shape, patch_shape, 1))\n",
"    \"\"\"\n",
"    fakes = g_model.predict(samples)\n",
"    return fakes, np.zeros((len(fakes), patch_shape, patch_shape, 1))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# generate samples and save as a plot and save the model\n",
"def summarize_performance(step, g_model, d_model, dataset, target_dir='', n_samples=3):\n",
"    \"\"\"Plot a few source / generated / target triples and save both models.\n",
"\n",
"    Args:\n",
"        step: zero-based counter; the plot file is numbered step + 1.\n",
"        g_model: generator, used for predictions and saved as g_model.h5.\n",
"        d_model: discriminator, saved as d_model.h5.\n",
"        dataset: [images_A, images_B] arrays scaled to [-1, 1].\n",
"        target_dir: output directory, created together with any missing\n",
"            parents; '' writes into the current directory.\n",
"        n_samples: number of randomly drawn samples (plot columns).\n",
"    \"\"\"\n",
"    if target_dir:\n",
"        # makedirs also creates missing parents (e.g. 'Models/' in 'Models/run/')\n",
"        os.makedirs(target_dir, exist_ok=True)\n",
"    # select a sample of input images and translate them\n",
"    [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)\n",
"    X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)\n",
"    # plot rows: real source, generated target, real target; rescale [-1,1] -> [0,1]\n",
"    rows = [(X_realA + 1) / 2.0, (X_fakeB + 1) / 2.0, (X_realB + 1) / 2.0]\n",
"    for row, images in enumerate(rows):\n",
"        for i in range(n_samples):\n",
"            plt.subplot(3, n_samples, 1 + row * n_samples + i)\n",
"            plt.axis('off')\n",
"            plt.imshow(images[i])\n",
"    # save plot to file\n",
"    filename1 = 'plot_%06d.png' % (step+1)\n",
"    plt.savefig(os.path.join(target_dir, filename1))\n",
"    plt.close()\n",
"    # save the generator and discriminator models\n",
"    g_model.save(os.path.join(target_dir, 'g_model.h5'))\n",
"    d_model.save(os.path.join(target_dir, 'd_model.h5'))\n",
"\n",
"    print('>Saved: %s and %s' % (filename1, 'g_model & d_model'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def generator(img_shape):\n",
"    \"\"\"Build the U-Net generator.\n",
"\n",
"    Seven stride-2 conv stages downsample the input. Six upsample+conv stages,\n",
"    each concatenated with the mirror-image encoder output, plus a final 2x\n",
"    upsample restore the resolution. A tanh conv emits a 3-channel image.\n",
"    \"\"\"\n",
"    def conv2d(layer_in, n_filter, norm=True):\n",
"        # stride-2 conv -> LeakyReLU(0.2) -> optional instance norm\n",
"        x = Conv2D(n_filter, kernel_size=4, strides=2, padding='same')(layer_in)\n",
"        x = LeakyReLU(0.2)(x)\n",
"        return InstanceNormalization()(x) if norm else x\n",
"\n",
"    def deconv2d(layer_in, skip_in, n_filter, dropout=0.5):\n",
"        # 2x upsample -> conv(relu) -> optional dropout -> instance norm -> concat skip\n",
"        x = UpSampling2D(size=2)(layer_in)\n",
"        x = Conv2D(n_filter, kernel_size=4, strides=1, padding='same', activation='relu')(x)\n",
"        x = Dropout(dropout)(x) if dropout else x\n",
"        x = InstanceNormalization()(x)\n",
"        return Concatenate()([x, skip_in])\n",
"\n",
"    in_img = Input(shape=img_shape)\n",
"\n",
"    # Encoder: keep every stage output for the skip connections.\n",
"    encoder_filters = (64, 128, 256, 512, 512, 512, 512)\n",
"    skips = []\n",
"    x = in_img\n",
"    for depth, n_filter in enumerate(encoder_filters):\n",
"        x = conv2d(x, n_filter, norm=depth > 0)  # no norm on the first stage\n",
"        skips.append(x)\n",
"\n",
"    # Decoder: start from the deepest output and concatenate the remaining\n",
"    # encoder outputs in reverse order; only the first three stages use dropout.\n",
"    decoder_plan = ((512, 0.5), (512, 0.5), (512, 0.5), (256, 0), (128, 0), (64, 0))\n",
"    for skip, (n_filter, drop) in zip(reversed(skips[:-1]), decoder_plan):\n",
"        x = deconv2d(x, skip, n_filter, dropout=drop)\n",
"\n",
"    x = UpSampling2D(size=2)(x)\n",
"    out_img = Conv2D(3, kernel_size=4, strides=1, padding='same', activation='tanh')(x)\n",
"\n",
"    return Model(in_img, out_img, name='generator')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def discriminator(img_shape):\n",
"    \"\"\"Build a conditional PatchGAN discriminator.\n",
"\n",
"    Inputs are a source image and a (real or generated) target image of shape\n",
"    img_shape. They are concatenated on the channel axis and passed through\n",
"    four chained stride-2 conv stages (64, 128, 256, 512 filters). The output\n",
"    is a one-channel grid of per-patch sigmoid probabilities; for a 256x256\n",
"    input the grid is 16x16.\n",
"    \"\"\"\n",
"    def d_layer(layer_in, n_filter, norm=True):\n",
"        # stride-2 conv -> LeakyReLU(0.2) -> optional instance norm\n",
"        # NOTE(review): InstanceNormalization() uses its default axis here;\n",
"        # confirm whether per-channel normalisation (axis=-1) was intended.\n",
"        d = Conv2D(n_filter, kernel_size=4, strides=2, padding='same')(layer_in)\n",
"        d = LeakyReLU(0.2)(d)\n",
"        if norm:\n",
"            d = InstanceNormalization()(d)\n",
"        return d\n",
"\n",
"    in_src_img = Input(shape=img_shape)\n",
"    in_target_img = Input(shape=img_shape)\n",
"\n",
"    merged = Concatenate()([in_src_img, in_target_img])\n",
"\n",
"    # each stage feeds the next, so all four convolutions are part of the model\n",
"    d1 = d_layer(merged, 64, norm=False)\n",
"    d2 = d_layer(d1, 128)\n",
"    d3 = d_layer(d2, 256)\n",
"    d4 = d_layer(d3, 512)\n",
"\n",
"    # sigmoid: the model is trained with binary_crossentropy, which expects\n",
"    # probabilities in [0, 1] rather than raw scores\n",
"    out = Conv2D(1, kernel_size=4, strides=1, padding='same', activation='sigmoid')(d4)\n",
"\n",
"    return Model([in_src_img, in_target_img], out, name='discriminator')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GAN"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def GAN(g_model, d_model, img_shape):\n",
"    \"\"\"Chain generator and discriminator into the model used to train the generator.\n",
"\n",
"    Returns a Model mapping a source image to [discriminator output for\n",
"    (source, generated), generated image]. Sets d_model.trainable = False so\n",
"    the discriminator is frozen inside this model.\n",
"    \"\"\"\n",
"    # NOTE(review): the flag lives on d_model itself, so it also applies to any\n",
"    # later d_model.compile(); unfreeze before compiling the discriminator.\n",
"    d_model.trainable = False\n",
"    src_img = Input(shape=img_shape)\n",
"    fake_img = g_model(src_img)\n",
"    return Model(src_img, [d_model([src_img, fake_img]), fake_img], name='GAN')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train GAN model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def train(d_model, g_model, gan_model, data, target_dir, n_epochs=100, n_batch=16, summarize_every=1):\n",
"    \"\"\"Alternately update the discriminator and the generator on paired image files.\n",
"\n",
"    Args:\n",
"        d_model: compiled discriminator; inputs [A, B], output a square grid of patch scores.\n",
"        g_model: generator mapping domain-A images to domain-B images.\n",
"        gan_model: compiled combined model with outputs [patch scores, generated image].\n",
"        data: [paths_A, paths_B], equal-length sequences of image paths; entry i of\n",
"            each forms one training pair (both are shuffled together every epoch).\n",
"        target_dir: directory passed to summarize_performance for plots and models.\n",
"        n_epochs: number of passes over the data.\n",
"        n_batch: pairs per step; a trailing partial batch is skipped.\n",
"        summarize_every: call summarize_performance after every this-many epochs\n",
"            (default 1 = after each epoch).\n",
"\n",
"    Raises:\n",
"        ValueError: if there are fewer than n_batch pairs (no full batch to train on).\n",
"    \"\"\"\n",
"    # side length of the discriminator's output grid, used for the label arrays\n",
"    n_patch = d_model.output_shape[1]\n",
"\n",
"    src_paths, tgt_paths = data[0], data[1]\n",
"    n_steps = len(src_paths) // n_batch\n",
"    if n_steps == 0:\n",
"        raise ValueError('need at least %d image pairs, got %d' % (n_batch, len(src_paths)))\n",
"\n",
"    for epoch in range(n_epochs):\n",
"        print(' ========== Epoch', epoch+1, '========== ')\n",
"\n",
"        # shuffle both lists with the same permutation so A/B pairs stay aligned\n",
"        src_paths, tgt_paths = shuffle(src_paths, tgt_paths)\n",
"\n",
"        for step in range(n_steps):\n",
"            start, end = step * n_batch, (step + 1) * n_batch\n",
"            dataset = [load_images(src_paths[start:end]), load_images(tgt_paths[start:end])]\n",
"\n",
"            # train on exactly the pairs just loaded (no resampling with\n",
"            # replacement), so every pair is used once per epoch\n",
"            X_realA, X_realB = dataset\n",
"            y_real = np.ones((len(X_realA), n_patch, n_patch, 1))\n",
"\n",
"            # generate a batch of fake samples\n",
"            X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)\n",
"\n",
"            # update discriminator on real pairs, then on generated pairs\n",
"            d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)\n",
"            d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)\n",
"            d_loss = 0.5 * np.add(d_loss1, d_loss2)\n",
"\n",
"            # update the generator through the combined model\n",
"            g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])\n",
"\n",
"            print('Batch : %d, D Loss : %.3f | G Loss : %.3f' % (step+1, d_loss, g_loss))\n",
"\n",
"        # plot samples from the last batch of the epoch and save both models\n",
"        if (epoch + 1) % summarize_every == 0:\n",
"            summarize_performance(epoch, g_model, d_model, dataset, target_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loss Functions"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import keras.backend as K\n",
"from keras.losses import mean_absolute_error\n",
"\n",
"def pixel_loss(y_true, y_pred):\n",
"    \"\"\"Mean absolute (L1) error over all elements of the batch.\"\"\"\n",
"    abs_err = K.abs(y_true - y_pred)\n",
"    return K.mean(abs_err)\n",
"\n",
"def contextual_loss(y_true, y_pred):\n",
"    \"\"\"Per-sample KL divergence between grayscale intensity distributions.\n",
"\n",
"    Both images (assumed RGB in [-1, 1]) are converted to grayscale, shifted\n",
"    to [0, 1], flattened and normalised to sum to 1. The result is\n",
"    KL(p_true || p_pred), one value per sample. Works for any batch size and\n",
"    spatial size; rgb_to_grayscale requires a last dimension of 3.\n",
"    \"\"\"\n",
"    eps = K.epsilon()\n",
"\n",
"    def _intensity_distribution(images):\n",
"        gray = tf.image.rgb_to_grayscale(images)\n",
"        flat = tf.reshape(gray, [tf.shape(gray)[0], -1])\n",
"        # eps keeps every entry strictly positive so the log and the\n",
"        # division below can never hit zero\n",
"        flat = (flat + 1.0) / 2.0 + eps\n",
"        return flat / K.sum(flat, axis=1, keepdims=True)\n",
"\n",
"    p_ = _intensity_distribution(y_true)\n",
"    q_ = _intensity_distribution(y_pred)\n",
"    return K.sum(p_ * K.log(p_ / q_), axis=1)\n",
"\n",
"def total_loss(y_true, y_pred):\n",
"    \"\"\"Generator reconstruction loss: 0.2 * L1 pixel loss + 0.8 * contextual loss.\"\"\"\n",
"    return 0.2 * pixel_loss(y_true, y_pred) + 0.8 * contextual_loss(y_true, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# dataset paths: one folder per domain, paired by natural-sorted position\n",
"b_photo_path = 'Dataset/Augmented photo/'\n",
"b_sketch_path = 'Dataset/Augmented sketch/'\n",
"\n",
"blue_photo = load_filename(b_photo_path)\n",
"blue_sketch = load_filename(b_sketch_path)\n",
"\n",
"# training pairs photo i with sketch i, so fail fast if the folders disagree\n",
"assert len(blue_photo) == len(blue_sketch), 'photo/sketch counts differ: %d vs %d' % (len(blue_photo), len(blue_sketch))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x204a34c4320>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# preview one image from the photo list (cv2 reads BGR; convert to RGB for matplotlib)\n",
"plt.imshow(\n",
"    cv2.cvtColor(cv2.imread(blue_photo[1102]).astype('uint8'), cv2.COLOR_BGR2RGB)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x204a35a7a20>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# preview the matching image from the sketch list (BGR -> RGB for matplotlib)\n",
"plt.imshow(\n",
"    cv2.cvtColor(cv2.imread(blue_sketch[1102]).astype('uint8'), cv2.COLOR_BGR2RGB)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define GAN Model"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# model input: 256x256 (load_images' default resize) with 3 channels\n",
"img_shape = (256, 256, 3)\n",
"\n",
"d_model = discriminator(img_shape)\n",
"\n",
"g_model = generator(img_shape)\n",
"\n",
"# side effect: GAN() sets d_model.trainable = False\n",
"gan_model = GAN(g_model, d_model, img_shape)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"GAN\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input_4 (InputLayer) (None, 256, 256, 3) 0 \n",
"__________________________________________________________________________________________________\n",
"generator (Model) (None, 256, 256, 3) 41825691 input_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"discriminator (Model) (None, 64, 64, 1) 539203 input_4[0][0] \n",
" generator[1][0] \n",
"==================================================================================================\n",
"Total params: 42,364,894\n",
"Trainable params: 41,825,691\n",
"Non-trainable params: 539,203\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"# layer overview of the combined model (generator + frozen discriminator)\n",
"gan_model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# one optimizer per compiled model so they do not share Adam's internal state\n",
"d_opt = Adam(lr=2e-4, beta_1=0.5)\n",
"gan_opt = Adam(lr=2e-4, beta_1=0.5)\n",
"\n",
"# GAN() left d_model.trainable = False. Keras collects a model's trainable\n",
"# weights when it is compiled, so unfreeze d_model for its own compile and\n",
"# freeze it again before compiling the combined model.\n",
"d_model.trainable = True\n",
"d_model.compile(loss='binary_crossentropy', optimizer=d_opt, loss_weights=[0.5])\n",
"d_model.trainable = False\n",
"gan_model.compile(loss=['binary_crossentropy', total_loss], optimizer=gan_opt, loss_weights=[1,100])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start Training"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ========== Epoch 1 ========== \n",
"WARNING:tensorflow:From c:\\users\\user\\anaconda3\\envs\\tf-gpu-1\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.\n",
"\n",
"WARNING:tensorflow:From c:\\users\\user\\anaconda3\\envs\\tf-gpu-1\\lib\\site-packages\\tensorflow\\python\\ops\\math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
"Batch : 1, D Loss : 2.262 | G Loss : 42.058\n",
"Batch : 2, D Loss : 1.777 | G Loss : 21.263\n",
"Batch : 3, D Loss : 2.124 | G Loss : 16.144\n",
"Batch : 4, D Loss : 1.989 | G Loss : 14.232\n",
"Batch : 5, D Loss : 2.531 | G Loss : 13.015\n",
"Batch : 6, D Loss : 4.515 | G Loss : 13.539\n",
"Batch : 7, D Loss : 2.256 | G Loss : 12.942\n",
"Batch : 8, D Loss : 2.218 | G Loss : 11.719\n",
"Batch : 9, D Loss : 4.467 | G Loss : 12.478\n",
"Batch : 10, D Loss : 2.330 | G Loss : 10.870\n",
"Batch : 11, D Loss : 2.607 | G Loss : 11.588\n",
"Batch : 12, D Loss : 3.373 | G Loss : 11.687\n",
"Batch : 13, D Loss : 2.681 | G Loss : 11.481\n",
"Batch : 14, D Loss : 2.573 | G Loss : 10.598\n",
"Batch : 15, D Loss : 2.629 | G Loss : 11.233\n",
"Batch : 16, D Loss : 4.183 | G Loss : 11.090\n",
"Batch : 17, D Loss : 2.189 | G Loss : 11.511\n",
"Batch : 18, D Loss : 2.335 | G Loss : 10.908\n",
"Batch : 19, D Loss : 4.519 | G Loss : 12.389\n",
"Batch : 20, D Loss : 2.727 | G Loss : 11.845\n",
"Batch : 21, D Loss : 2.600 | G Loss : 11.249\n",
"Batch : 22, D Loss : 2.516 | G Loss : 10.602\n",
"Batch : 23, D Loss : 2.363 | G Loss : 10.622\n",
"Batch : 24, D Loss : 4.347 | G Loss : 10.624\n",
"Batch : 25, D Loss : 2.597 | G Loss : 8.816\n",
"Batch : 26, D Loss : 1.898 | G Loss : 11.320\n",
"Batch : 27, D Loss : 4.080 | G Loss : 11.499\n",
"Batch : 28, D Loss : 2.402 | G Loss : 7.653\n",
"Batch : 29, D Loss : 2.225 | G Loss : 7.833\n",
"Batch : 30, D Loss : 2.476 | G Loss : 7.711\n",
"Batch : 31, D Loss : 2.050 | G Loss : 6.253\n",
"Batch : 32, D Loss : 2.581 | G Loss : 7.128\n",
"Batch : 33, D Loss : 2.724 | G Loss : 7.293\n",
"Batch : 34, D Loss : 2.261 | G Loss : 7.426\n",
"Batch : 35, D Loss : 2.261 | G Loss : 6.352\n",
"Batch : 36, D Loss : 2.333 | G Loss : 6.632\n",
"Batch : 37, D Loss : 2.447 | G Loss : 6.757\n",
"Batch : 38, D Loss : 2.398 | G Loss : 6.884\n",
"Batch : 39, D Loss : 2.483 | G Loss : 6.786\n",
"Batch : 40, D Loss : 2.418 | G Loss : 6.549\n",
"Batch : 41, D Loss : 2.003 | G Loss : 7.247\n",
"Batch : 42, D Loss : 2.395 | G Loss : 6.834\n",
"Batch : 43, D Loss : 2.195 | G Loss : 6.685\n",
"Batch : 44, D Loss : 2.534 | G Loss : 6.311\n",
"Batch : 45, D Loss : 2.350 | G Loss : 6.295\n",
"Batch : 46, D Loss : 2.208 | G Loss : 6.204\n",
"Batch : 47, D Loss : 2.295 | G Loss : 6.627\n",
"Batch : 48, D Loss : 2.279 | G Loss : 5.951\n",
"Batch : 49, D Loss : 2.420 | G Loss : 6.254\n",
"Batch : 50, D Loss : 2.112 | G Loss : 6.072\n",
"Batch : 51, D Loss : 2.130 | G Loss : 6.149\n",
"Batch : 52, D Loss : 1.803 | G Loss : 6.464\n",
"Batch : 53, D Loss : 2.267 | G Loss : 6.190\n",
"Batch : 54, D Loss : 2.437 | G Loss : 7.174\n",
"Batch : 55, D Loss : 2.271 | G Loss : 6.602\n",
"Batch : 56, D Loss : 1.843 | G Loss : 4.610\n",
"Batch : 57, D Loss : 2.380 | G Loss : 6.033\n",
"Batch : 58, D Loss : 2.313 | G Loss : 6.456\n",
"Batch : 59, D Loss : 2.274 | G Loss : 6.569\n",
"Batch : 60, D Loss : 2.274 | G Loss : 6.242\n",
"Batch : 61, D Loss : 2.166 | G Loss : 6.124\n",
"Batch : 62, D Loss : 2.320 | G Loss : 6.085\n",
"Batch : 63, D Loss : 2.126 | G Loss : 6.737\n",
"Batch : 64, D Loss : 2.252 | G Loss : 6.092\n",
"Batch : 65, D Loss : 2.046 | G Loss : 5.496\n",
"Batch : 66, D Loss : 2.390 | G Loss : 6.025\n",
"Batch : 67, D Loss : 2.049 | G Loss : 5.764\n",
"Batch : 68, D Loss : 2.294 | G Loss : 5.353\n",
"Batch : 69, D Loss : 1.950 | G Loss : 5.273\n",
"Batch : 70, D Loss : 2.556 | G Loss : 6.479\n",
"Batch : 71, D Loss : 2.027 | G Loss : 6.512\n",
"Batch : 72, D Loss : 1.907 | G Loss : 5.601\n",
"Batch : 73, D Loss : 2.565 | G Loss : 7.373\n",
"Batch : 74, D Loss : 2.228 | G Loss : 8.712\n",
"Batch : 75, D Loss : 2.384 | G Loss : 7.407\n",
"Batch : 76, D Loss : 2.579 | G Loss : 7.034\n",
"Batch : 77, D Loss : 2.062 | G Loss : 6.939\n",
"Batch : 78, D Loss : 2.608 | G Loss : 6.352\n",
"Batch : 79, D Loss : 2.226 | G Loss : 5.870\n",
"Batch : 80, D Loss : 2.528 | G Loss : 6.487\n",
"Batch : 81, D Loss : 2.210 | G Loss : 5.880\n",
"Batch : 82, D Loss : 2.317 | G Loss : 5.851\n",
"Batch : 83, D Loss : 1.921 | G Loss : 5.078\n",
"Batch : 84, D Loss : 2.509 | G Loss : 6.861\n",
"Batch : 85, D Loss : 2.451 | G Loss : 6.088\n",
"Batch : 86, D Loss : 2.200 | G Loss : 5.315\n",
"Batch : 87, D Loss : 2.164 | G Loss : 6.233\n",
"Batch : 88, D Loss : 2.308 | G Loss : 5.804\n",
"Batch : 89, D Loss : 2.092 | G Loss : 5.313\n",
"Batch : 90, D Loss : 2.312 | G Loss : 5.361\n",
"Batch : 91, D Loss : 2.358 | G Loss : 5.635\n",
"Batch : 92, D Loss : 2.243 | G Loss : 5.745\n",
"Batch : 93, D Loss : 2.389 | G Loss : 6.473\n",
"Batch : 94, D Loss : 2.225 | G Loss : 5.830\n",
"Batch : 95, D Loss : 2.139 | G Loss : 5.547\n",
"Batch : 96, D Loss : 2.051 | G Loss : 5.265\n",
"Batch : 97, D Loss : 2.164 | G Loss : 5.894\n",
"Batch : 98, D Loss : 2.172 | G Loss : 5.499\n",
"Batch : 99, D Loss : 2.356 | G Loss : 5.996\n",
"Batch : 100, D Loss : 2.187 | G Loss : 5.309\n",
"Batch : 101, D Loss : 1.820 | G Loss : 5.354\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-17-2899ecd49915>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0md_model\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mg_model\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgan_model\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mblue_sketch\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mblue_photo\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'Pixel[02]_Context[08]/'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mn_epochs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m100\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mn_batch\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m16\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;32m<ipython-input-8-cd00650cfdb5>\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(d_model, g_model, gan_model, data, target_dir, n_epochs, n_batch)\u001b[0m\n\u001b[0;32m 28\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 29\u001b[0m \u001b[1;31m# update discriminator for generated samples\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 30\u001b[1;33m \u001b[0md_loss2\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0md_model\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain_on_batch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mX_realA\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_fakeB\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_fake\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 31\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 32\u001b[0m \u001b[0md_loss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0.5\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0md_loss1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0md_loss2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\users\\user\\anaconda3\\envs\\tf-gpu-1\\lib\\site-packages\\keras\\engine\\training.py\u001b[0m in \u001b[0;36mtrain_on_batch\u001b[1;34m(self, x, y, sample_weight, class_weight, reset_metrics)\u001b[0m\n\u001b[0;32m 1512\u001b[0m \u001b[0mins\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0msample_weights\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_make_train_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1514\u001b[1;33m \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mins\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1515\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mreset_metrics\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mc:\\users\\user\\anaconda3\\envs\\tf-gpu-1\\lib\\site-packages\\tensorflow\\python\\keras\\backend.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, inputs)\u001b[0m\n\u001b[0;32m 3290\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3291\u001b[0m fetched = self._callable_fn(*array_vals,\n\u001b[1;32m-> 3292\u001b[1;33m run_metadata=self.run_metadata)\n\u001b[0m\u001b[0;32m 3293\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_fetch_callbacks\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfetched\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m-\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_fetches\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3294\u001b[0m output_structure = nest.pack_sequence_as(\n",
"\u001b[1;32mc:\\users\\user\\anaconda3\\envs\\tf-gpu-1\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1456\u001b[0m ret = tf_session.TF_SessionRunCallable(self._session._session,\n\u001b[0;32m 1457\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1458\u001b[1;33m run_metadata_ptr)\n\u001b[0m\u001b[0;32m 1459\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1460\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# domain A (generator input) = sketches, domain B (target) = photos\n",
"train(\n",
"    d_model, g_model, gan_model,\n",
"    [blue_sketch, blue_photo],\n",
"    'Models/Pixel[02]_Context[08]/',\n",
"    n_epochs=100,\n",
"    n_batch=16,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|