{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "enhance-me-train.ipynb",
"provenance": [],
"collapsed_sections": [],
"machine_shape": "hm",
"authorship_tag": "ABX9TyN4LuJh6kWhbqxzA5s9sp7k",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"[enhance-me on GitHub (branch `mirnet`)](https://github.com/soumik12345/enhance-me/tree/mirnet)"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1JryaVhtBHij",
"outputId": "97ee6a4a-2479-4124-e96a-f0a792bdec46"
},
"source": [
"# Skip the clone when the checkout already exists, so re-running the cell does not error.\n",
"![ -d enhance-me ] || git clone https://github.com/soumik12345/enhance-me -b mirnet\n",
"!pip install -qqq wandb streamlit"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'enhance-me'...\n",
"remote: Enumerating objects: 89, done.\u001b[K\n",
"remote: Counting objects: 100% (89/89), done.\u001b[K\n",
"remote: Compressing objects: 100% (61/61), done.\u001b[K\n",
"remote: Total 89 (delta 43), reused 63 (delta 23), pack-reused 0\u001b[K\n",
"Unpacking objects: 100% (89/89), done.\n",
"\u001b[K |████████████████████████████████| 1.7 MB 8.2 MB/s \n",
"\u001b[K |████████████████████████████████| 9.1 MB 33.4 MB/s \n",
"\u001b[K |████████████████████████████████| 140 kB 74.7 MB/s \n",
"\u001b[K |████████████████████████████████| 97 kB 8.6 MB/s \n",
"\u001b[K |████████████████████████████████| 180 kB 83.6 MB/s \n",
"\u001b[K |████████████████████████████████| 63 kB 2.1 MB/s \n",
"\u001b[K |████████████████████████████████| 4.3 MB 83.4 MB/s \n",
"\u001b[K |████████████████████████████████| 178 kB 68.0 MB/s \n",
"\u001b[K |████████████████████████████████| 76 kB 6.0 MB/s \n",
"\u001b[K |████████████████████████████████| 111 kB 81.8 MB/s \n",
"\u001b[K |████████████████████████████████| 125 kB 86.7 MB/s \n",
"\u001b[K |████████████████████████████████| 791 kB 67.2 MB/s \n",
"\u001b[K |████████████████████████████████| 374 kB 83.4 MB/s \n",
"\u001b[?25h Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for pympler (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for blinker (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"jupyter-console 5.2.0 requires prompt-toolkit<2.0.0,>=1.0.0, but you have prompt-toolkit 3.0.23 which is incompatible.\n",
"google-colab 1.0.0 requires ipykernel~=4.10, but you have ipykernel 6.5.1 which is incompatible.\n",
"google-colab 1.0.0 requires ipython~=5.5.0, but you have ipython 7.30.0 which is incompatible.\u001b[0m\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "G_c4VtXWHR5l"
},
"source": [
"import sys\n",
"sys.path.append(\"./enhance-me\")\n",
"\n",
"from PIL import Image\n",
"from enhance_me import commons\n",
"from enhance_me.mirnet import MIRNet"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ZpBHbYaMIqP_"
},
"source": [
"#@title MIRNet Train Configs\n",
"\n",
"# NOTE(review): the run name mentions 256 while image_size is 128; keep experiment_name in sync with image_size.\n",
"experiment_name = 'lol_dataset_256' #@param {type:\"string\"}\n",
"image_size = 128 #@param {type:\"integer\"}\n",
"dataset_label = 'lol' #@param [\"lol\"]\n",
"apply_random_horizontal_flip = True #@param {type:\"boolean\"}\n",
"apply_random_vertical_flip = True #@param {type:\"boolean\"}\n",
"apply_random_rotation = True #@param {type:\"boolean\"}\n",
"# Prefer `wandb login` over pasting a key here: form values are saved inside the notebook source.\n",
"wandb_api_key = '' #@param {type:\"string\"}\n",
"# Fraction of the data held out for validation; 1.0 would leave nothing to train on.\n",
"val_split = 0.1 #@param {type:\"slider\", min:0.1, max:0.9, step:0.1}\n",
"batch_size = 4 #@param {type:\"integer\"}\n",
"num_recursive_residual_groups = 3 #@param {type:\"slider\", min:1, max:5, step:1}\n",
"num_multi_scale_residual_blocks = 2 #@param {type:\"slider\", min:1, max:5, step:1}\n",
"learning_rate = 1e-4 #@param {type:\"number\"}\n",
"epsilon = 1e-3 #@param {type:\"number\"}\n",
"epochs = 50 #@param {type:\"slider\", min:10, max:100, step:5}\n",
"\n",
"# Fail fast on configurations that cannot produce a usable train/validation split.\n",
"assert 0.0 < val_split < 1.0, 'val_split must leave some data for training'\n",
"assert batch_size > 0, 'batch_size must be a positive integer'"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"id": "IVRoedqBIMuH",
"outputId": "53ca5beb-871a-4ec3-b757-173e09a15331"
},
"source": [
"# An empty form field is sent as None rather than as an empty-string key.\n",
"mirnet = MIRNet(\n",
"    experiment_name=experiment_name,\n",
"    wandb_api_key=wandb_api_key or None,\n",
")"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m19soumik-rakshit96\u001b[0m (use `wandb login --relogin` to force relogin)\n"
]
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" Syncing run lol_dataset_256 to Weights & Biases (docs).<br/>\n",
"\n",
" "
],
"text/plain": [
""
]
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "O66Iwzx8IsGh",
"outputId": "0b6f1683-65d1-4737-a32f-d36b331d2bc2"
},
"source": [
"# Gather the data-pipeline settings from the config cell, then build the train/validation splits.\n",
"dataset_kwargs = dict(\n",
"    image_size=image_size,\n",
"    dataset_label=dataset_label,\n",
"    apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
"    apply_random_vertical_flip=apply_random_vertical_flip,\n",
"    apply_random_rotation=apply_random_rotation,\n",
"    val_split=val_split,\n",
"    batch_size=batch_size,\n",
")\n",
"mirnet.build_datasets(**dataset_kwargs)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading data from https://github.com/soumik12345/enhance-me/releases/download/v0.1/lol_dataset.zip\n",
"347176960/347171015 [==============================] - 13s 0us/step\n",
"347185152/347171015 [==============================] - 13s 0us/step\n",
"Number of train data points: 436\n",
"Number of validation data points: 49\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "tsfKrBCsL_Bb"
},
"source": [
"# Model-size and training hyper-parameters, all taken from the config cell.\n",
"model_kwargs = dict(\n",
"    num_recursive_residual_groups=num_recursive_residual_groups,\n",
"    num_multi_scale_residual_blocks=num_multi_scale_residual_blocks,\n",
"    learning_rate=learning_rate,\n",
"    epsilon=epsilon,\n",
")\n",
"mirnet.build_model(**model_kwargs)"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "y3L9wlpkNziL",
"outputId": "5149f0e7-91f4-450f-c43a-1b6028692bbc"
},
"source": [
"%%time\n",
"history = mirnet.train(epochs=epochs)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py:1410: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n",
" layer_config = serialize_layer_fn(layer)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/50\n",
" 66/218 [========>.....................] - ETA: 2:25 - loss: 0.1721 - peak_signal_noise_ratio: 63.2555"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"background_save": true
},
"id": "daFKbgBkiyzc"
},
"source": [
"# For every low-light test image, show the input, its ground truth and the MIRNet output side by side.\n",
"# Note: `test_enhanced_images` holds the ground-truth files; `enhanced_image` is the model prediction.\n",
"for index, low_image_file in enumerate(mirnet.test_low_images):\n",
"    original_image = Image.open(low_image_file)\n",
"    enhanced_image = mirnet.infer(original_image)\n",
"    ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
"    # Panels: input, reference, model output (the third panel must be the prediction, not the ground truth).\n",
"    commons.plot_results(\n",
"        [original_image, ground_truth, enhanced_image],\n",
"        [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
"        (18, 18)\n",
"    )"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dO-IbNQHkB3R"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}