{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "enhance-me-train.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm",
      "authorship_tag": "ABX9TyN4LuJh6kWhbqxzA5s9sp7k",
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "\"Open"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Train MIRNet with enhance-me\n",
        "\n",
        "This notebook clones the `mirnet` branch of [enhance-me](https://github.com/soumik12345/enhance-me). It then builds the LOL train/validation datasets and a MIRNet model from the configuration form below, and trains the model. Finally, it plots every test image next to its ground truth and the model's enhanced output.\n",
        "\n",
        "The notebook metadata requests a GPU runtime."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1JryaVhtBHij",
        "outputId": "97ee6a4a-2479-4124-e96a-f0a792bdec46"
      },
      "source": [
        "!git clone https://github.com/soumik12345/enhance-me -b mirnet\n",
        "!pip install -qqq wandb streamlit"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'enhance-me'...\n",
            "remote: Enumerating objects: 89, done.\u001b[K\n",
            "remote: Counting objects: 100% (89/89), done.\u001b[K\n",
            "remote: Compressing objects: 100% (61/61), done.\u001b[K\n",
            "remote: Total 89 (delta 43), reused 63 (delta 23), pack-reused 0\u001b[K\n",
            "Unpacking objects: 100% (89/89), done.\n",
            "\u001b[K |████████████████████████████████| 1.7 MB 8.2 MB/s \n",
            "\u001b[K |████████████████████████████████| 9.1 MB 33.4 MB/s \n",
            "\u001b[K |████████████████████████████████| 140 kB 74.7 MB/s \n",
            "\u001b[K |████████████████████████████████| 97 kB 8.6 MB/s \n",
            "\u001b[K |████████████████████████████████| 180 kB 83.6 MB/s \n",
            "\u001b[K |████████████████████████████████| 63 kB 2.1 MB/s \n",
            "\u001b[K |████████████████████████████████| 4.3 MB 83.4 MB/s \n",
            "\u001b[K |████████████████████████████████| 178 kB 68.0 MB/s \n",
            "\u001b[K |████████████████████████████████| 76 kB 6.0 MB/s \n",
            "\u001b[K |████████████████████████████████| 111 kB 81.8 MB/s \n",
            "\u001b[K |████████████████████████████████| 125 kB 86.7 MB/s \n",
            "\u001b[K |████████████████████████████████| 791 kB 67.2 MB/s \n",
            "\u001b[K |████████████████████████████████| 374 kB 83.4 MB/s \n",
            "\u001b[?25h Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            " Building wheel for pympler (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            " Building wheel for blinker (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "jupyter-console 5.2.0 requires prompt-toolkit<2.0.0,>=1.0.0, but you have prompt-toolkit 3.0.23 which is incompatible.\n",
            "google-colab 1.0.0 requires ipykernel~=4.10, but you have ipykernel 6.5.1 which is incompatible.\n",
            "google-colab 1.0.0 requires ipython~=5.5.0, but you have ipython 7.30.0 which is incompatible.\u001b[0m\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G_c4VtXWHR5l"
      },
      "source": [
        "import sys\n",
        "sys.path.append(\"./enhance-me\")\n",
        "\n",
        "from PIL import Image\n",
        "from enhance_me import commons\n",
        "from enhance_me.mirnet import MIRNet"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZpBHbYaMIqP_"
      },
      "source": [
        "#@title MIRNet Train Configs\n",
        "\n",
        "experiment_name = 'lol_dataset_256' #@param {type:\"string\"}\n",
        "image_size = 128 #@param {type:\"integer\"}\n",
        "dataset_label = 'lol' #@param [\"lol\"]\n",
        "apply_random_horizontal_flip = True #@param {type:\"boolean\"}\n",
        "apply_random_vertical_flip = True #@param {type:\"boolean\"}\n",
        "apply_random_rotation = True #@param {type:\"boolean\"}\n",
        "wandb_api_key = '' #@param {type:\"string\"}\n",
        "val_split = 0.1 #@param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
        "batch_size = 4 #@param {type:\"integer\"}\n",
        "num_recursive_residual_groups = 3 #@param {type:\"slider\", min:1, max:5, step:1}\n",
        "num_multi_scale_residual_blocks = 2 #@param {type:\"slider\", min:1, max:5, step:1}\n",
        "learning_rate = 1e-4 #@param {type:\"number\"}\n",
        "epsilon = 1e-3 #@param {type:\"number\"}\n",
        "epochs = 50 #@param {type:\"slider\", min:10, max:100, step:5}"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "id": "IVRoedqBIMuH",
        "outputId": "53ca5beb-871a-4ec3-b757-173e09a15331"
      },
      "source": [
        "mirnet = MIRNet(\n",
        "    experiment_name=experiment_name,\n",
        "    wandb_api_key=None if wandb_api_key == '' else wandb_api_key\n",
        ")"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m19soumik-rakshit96\u001b[0m (use `wandb login --relogin` to force relogin)\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              " Syncing run lol_dataset_256 to Weights & Biases (docs).\n",
              "\n",
              " "
            ],
            "text/plain": [
              ""
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "O66Iwzx8IsGh",
        "outputId": "0b6f1683-65d1-4737-a32f-d36b331d2bc2"
      },
      "source": [
        "mirnet.build_datasets(\n",
        "    image_size=image_size,\n",
        "    dataset_label=dataset_label,\n",
        "    apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
        "    apply_random_vertical_flip=apply_random_vertical_flip,\n",
        "    apply_random_rotation=apply_random_rotation,\n",
        "    val_split=val_split,\n",
        "    batch_size=batch_size\n",
        ")"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading data from https://github.com/soumik12345/enhance-me/releases/download/v0.1/lol_dataset.zip\n",
            "347176960/347171015 [==============================] - 13s 0us/step\n",
            "347185152/347171015 [==============================] - 13s 0us/step\n",
            "Number of train data points: 436\n",
            "Number of validation data points: 49\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tsfKrBCsL_Bb"
      },
      "source": [
        "mirnet.build_model(\n",
        "    num_recursive_residual_groups=num_recursive_residual_groups,\n",
        "    num_multi_scale_residual_blocks=num_multi_scale_residual_blocks,\n",
        "    learning_rate=learning_rate,\n",
        "    epsilon=epsilon\n",
        ")"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "y3L9wlpkNziL",
        "outputId": "5149f0e7-91f4-450f-c43a-1b6028692bbc"
      },
      "source": [
        "history = mirnet.train(epochs=epochs)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py:1410: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n",
            " layer_config = serialize_layer_fn(layer)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/50\n",
            " 66/218 [========>.....................] - ETA: 2:25 - loss: 0.1721 - peak_signal_noise_ratio: 63.2555"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "daFKbgBkiyzc"
      },
      "source": [
        "# Plot every test input next to its ground truth and the MIRNet output.\n",
        "# zip pairs each input with its ground-truth file by position, exactly as\n",
        "# the previous index-based lookup did.\n",
        "for low_image_file, ground_truth_file in zip(\n",
        "    mirnet.test_low_images, mirnet.test_enhanced_images\n",
        "):\n",
        "    original_image = Image.open(low_image_file)\n",
        "    enhanced_image = mirnet.infer(original_image)\n",
        "    ground_truth = Image.open(ground_truth_file)\n",
        "    # The third panel must be the model output, not a second copy of the\n",
        "    # ground truth, so that it matches its \"Enhanced Image\" title.\n",
        "    commons.plot_results(\n",
        "        [original_image, ground_truth, enhanced_image],\n",
        "        [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
        "        (18, 18)\n",
        "    )"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}